mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 04:55:00 +08:00
Compare commits
15 Commits
ciflow/roc
...
cpp-docs-d
| Author | SHA1 | Date | |
|---|---|---|---|
| 2913cdf29d | |||
| 0661a232a5 | |||
| 5db844dafa | |||
| 73efad99d7 | |||
| df1268c311 | |||
| 84f9f1541d | |||
| 27c0c126bf | |||
| 670873155a | |||
| 923737c510 | |||
| 13d5b14a73 | |||
| a35a42b21c | |||
| 15956bc1e8 | |||
| b319ea1111 | |||
| ce4c68a5f6 | |||
| c6da4a59a3 |
@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ARG DEVTOOLSET_VERSION=11
|
||||
|
||||
RUN yum -y update
|
||||
RUN yum -y install epel-release
|
||||
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
|
||||
RUN yum -y install glibc-langpack-en
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
|
||||
# Just add everything as a safe.directory for git since these will be used in multiple places with git
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
|
||||
# Install CUDA
|
||||
FROM base as cuda
|
||||
ARG CUDA_VERSION=12.6
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
RUN rm -rf /usr/local/cuda-*
|
||||
ADD ./common/install_cuda.sh install_cuda.sh
|
||||
COPY ./common/install_nccl.sh install_nccl.sh
|
||||
@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
|
||||
# Preserve CUDA_VERSION for the builds
|
||||
ENV CUDA_VERSION=${CUDA_VERSION}
|
||||
# Make things in our path by default
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
|
||||
|
||||
FROM cuda as cuda12.6
|
||||
RUN bash ./install_cuda.sh 12.6
|
||||
@ -70,22 +68,8 @@ FROM cuda as cuda13.0
|
||||
RUN bash ./install_cuda.sh 13.0
|
||||
ENV DESIRED_CUDA=13.0
|
||||
|
||||
FROM ${ROCM_IMAGE} as rocm_base
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV LC_ALL en_US.UTF-8
|
||||
ENV LANG en_US.UTF-8
|
||||
ENV LANGUAGE en_US.UTF-8
|
||||
# Install devtoolset on ROCm base image
|
||||
RUN yum -y update && \
|
||||
yum -y install epel-release && \
|
||||
yum -y install glibc-langpack-en && \
|
||||
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
|
||||
RUN git config --global --add safe.directory '*'
|
||||
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
|
||||
FROM rocm_base as rocm
|
||||
FROM ${ROCM_IMAGE} as rocm
|
||||
ARG PYTORCH_ROCM_ARCH
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
|
||||
ADD ./common/install_mkl.sh install_mkl.sh
|
||||
RUN bash ./install_mkl.sh && rm install_mkl.sh
|
||||
@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
|
||||
|
||||
# Final step
|
||||
FROM ${BASE_TARGET} as final
|
||||
ARG DEVTOOLSET_VERSION=13
|
||||
COPY --from=openssl /opt/openssl /opt/openssl
|
||||
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
|
||||
COPY --from=conda /opt/conda /opt/conda
|
||||
|
||||
@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
;;
|
||||
rocm*)
|
||||
BASE_TARGET=rocm
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
|
||||
;;
|
||||
*)
|
||||
@ -59,7 +63,7 @@ docker build \
|
||||
--target final \
|
||||
--progress plain \
|
||||
--build-arg "BASE_TARGET=${BASE_TARGET}" \
|
||||
--build-arg "DEVTOOLSET_VERSION=13" \
|
||||
--build-arg "DEVTOOLSET_VERSION=11" \
|
||||
${EXTRA_BUILD_ARGS} \
|
||||
-t ${tmp_tag} \
|
||||
$@ \
|
||||
|
||||
@ -168,18 +168,6 @@ case "$tag" in
|
||||
VISION=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.11-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.11
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
|
||||
if [[ $tag =~ "jammy" ]]; then
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
@ -207,9 +195,9 @@ case "$tag" in
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
|
||||
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=13
|
||||
GCC_VERSION=11
|
||||
VISION=yes
|
||||
XPU_VERSION=2025.2
|
||||
NINJA_VERSION=1.9.0
|
||||
@ -260,12 +248,6 @@ case "$tag" in
|
||||
HALIDE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
|
||||
CUDA_VERSION=12.8.1
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
GCC_VERSION=11
|
||||
PALLAS=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-triton-cpu)
|
||||
CUDA_VERSION=12.6
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
@ -279,9 +261,9 @@ case "$tag" in
|
||||
PYTHON_VERSION=3.10
|
||||
CUDA_VERSION=12.8.1
|
||||
;;
|
||||
pytorch-linux-jammy-aarch64-py3.10-gcc13)
|
||||
pytorch-linux-jammy-aarch64-py3.10-gcc11)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=13
|
||||
GCC_VERSION=11
|
||||
ACL=yes
|
||||
VISION=yes
|
||||
OPENBLAS=yes
|
||||
@ -289,19 +271,9 @@ case "$tag" in
|
||||
# from pytorch/llvm:9.0.1 is x86 specific
|
||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
||||
;;
|
||||
pytorch-linux-jammy-aarch64-py3.10-clang21)
|
||||
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
CLANG_VERSION=21
|
||||
ACL=yes
|
||||
VISION=yes
|
||||
OPENBLAS=yes
|
||||
# snadampal: skipping llvm src build install because the current version
|
||||
# from pytorch/llvm:9.0.1 is x86 specific
|
||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
||||
;;
|
||||
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=13
|
||||
GCC_VERSION=11
|
||||
ACL=yes
|
||||
VISION=yes
|
||||
OPENBLAS=yes
|
||||
@ -387,7 +359,6 @@ docker build \
|
||||
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
||||
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
||||
--build-arg "HALIDE=${HALIDE}" \
|
||||
--build-arg "PALLAS=${PALLAS}" \
|
||||
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
||||
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
||||
--build-arg "ACL=${ACL:-}" \
|
||||
|
||||
@ -1 +0,0 @@
|
||||
0.8.0
|
||||
@ -1 +1 @@
|
||||
0add68262ab0a2e33b84524346cb27cbb2787356
|
||||
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
|
||||
|
||||
@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
|
||||
# work around ubuntu apt-get conflicts
|
||||
sudo apt-get -y -f install
|
||||
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
|
||||
if [[ $CLANG_VERSION -ge 18 ]]; then
|
||||
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
|
||||
if [[ $CLANG_VERSION == 18 ]]; then
|
||||
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
|
||||
# Need the official toolchain repo to get alternate packages
|
||||
add-apt-repository ppa:ubuntu-toolchain-r/test
|
||||
apt-get update
|
||||
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
|
||||
apt-get install -y g++-$GCC_VERSION
|
||||
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
|
||||
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
|
||||
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
|
||||
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
|
||||
|
||||
|
||||
# Cleanup package manager
|
||||
apt-get autoclean && apt-get clean
|
||||
|
||||
@ -1,40 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
|
||||
|
||||
# Get the pinned JAX version (same for all CUDA versions)
|
||||
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
|
||||
|
||||
function install_jax_12() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
|
||||
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
|
||||
}
|
||||
|
||||
function install_jax_13() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
|
||||
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
|
||||
}
|
||||
|
||||
# idiomatic parameter and option handling in sh
|
||||
while test $# -gt 0
|
||||
do
|
||||
case "$1" in
|
||||
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
|
||||
;;
|
||||
13.0|13.0.*) install_jax_13;
|
||||
;;
|
||||
*) echo "bad argument $1"; exit 1
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
@ -1,56 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Script used only in CD pipeline
|
||||
|
||||
set -ex
|
||||
|
||||
# install dependencies
|
||||
dnf -y install gmp-devel libmpc-devel texinfo flex bison
|
||||
|
||||
cd /usr/local/src
|
||||
# fetch source for gcc 13
|
||||
git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
|
||||
|
||||
mkdir -p gcc-13.3.0/build-gomp
|
||||
cd gcc-13.3.0/build-gomp
|
||||
|
||||
# configure gcc build
|
||||
# I got these flags by:
|
||||
# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
|
||||
# dnf install -y dnf-plugins-core rpmdevtools
|
||||
# dnf download --source libgomp
|
||||
# 2. extracting the gcc.spec from the source.
|
||||
# rpmdev-extract gcc-xx.src.rpm
|
||||
# 3. extracting optflags and ld_flags from gcc.spec:
|
||||
# rpm --eval '%{optflags}'
|
||||
# rpm --eval '%{build_ldflags}'
|
||||
#
|
||||
# I had to remove the following flags because they didn't compile for this version of libgomp:
|
||||
# -Werror=format-security
|
||||
# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
|
||||
# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
|
||||
#
|
||||
# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
|
||||
|
||||
OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
|
||||
' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
|
||||
' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
|
||||
' -fstack-protector-strong -fasynchronous-unwind-tables'\
|
||||
' -fstack-clash-protection'
|
||||
|
||||
LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
|
||||
|
||||
CFLAGS="$OPT_FLAGS" \
|
||||
CXXFLAGS="$OPT_FLAGS" \
|
||||
LDFLAGS="$LDFLAGS" \
|
||||
../configure \
|
||||
--prefix=/usr \
|
||||
--libdir=/usr/lib64 \
|
||||
--enable-languages=c,c++ \
|
||||
--disable-multilib \
|
||||
--disable-bootstrap \
|
||||
--enable-libgomp
|
||||
|
||||
# only build libgomp
|
||||
make -j$(nproc) all-target-libgomp
|
||||
|
||||
make install-target-libgomp
|
||||
@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
|
||||
|
||||
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
|
||||
OPENBLAS_BUILD_FLAGS="
|
||||
CC=gcc
|
||||
NUM_THREADS=128
|
||||
USE_OPENMP=1
|
||||
NO_SHARED=0
|
||||
|
||||
@ -9,7 +9,7 @@ set -xe
|
||||
|
||||
function install_ubuntu() {
|
||||
. /etc/os-release
|
||||
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
echo "Ubuntu version ${VERSION_CODENAME} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -35,24 +35,25 @@ function install_ubuntu() {
|
||||
# The xpu-smi packages
|
||||
apt-get install -y flex bison xpu-smi
|
||||
|
||||
# Compute and Media Runtimes
|
||||
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Compute and Media Runtimes
|
||||
apt-get install -y \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
intel-opencl-icd intel-level-zero-gpu level-zero \
|
||||
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
else # jammy
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
|
||||
else # rolling driver
|
||||
apt-get install -y \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
fi
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
|
||||
# Install Intel Support Packages
|
||||
apt-get install -y ${XPU_PACKAGES}
|
||||
@ -65,7 +66,7 @@ function install_ubuntu() {
|
||||
function install_rhel() {
|
||||
. /etc/os-release
|
||||
if [[ "${ID}" == "rhel" ]]; then
|
||||
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
echo "RHEL version ${VERSION_ID} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -146,7 +147,7 @@ function install_sles() {
|
||||
XPU_DRIVER_VERSION=""
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Use GPU driver LTS releases
|
||||
XPU_DRIVER_VERSION="/lts/2523"
|
||||
XPU_DRIVER_VERSION="/lts/2350"
|
||||
fi
|
||||
|
||||
# Default use Intel® oneAPI Deep Learning Essentials 2025.1
|
||||
|
||||
@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
fi
|
||||
BASE_TARGET=rocm
|
||||
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -50,10 +50,6 @@ RUN rm install_ninja.sh
|
||||
ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Build a newer version of libgomp than that supported in in Almalinux 8.
|
||||
COPY ./common/install_libgomp.sh install_libgomp.sh
|
||||
RUN bash ./install_libgomp.sh && rm install_libgomp.sh
|
||||
|
||||
# git236+ would refuse to run git commands in repos owned by other users
|
||||
# Which causes version check to fail, as pytorch repo is bind-mounted into the image
|
||||
# Override this behaviour by treating every folder as safe
|
||||
|
||||
@ -87,7 +87,11 @@ case ${image} in
|
||||
MANY_LINUX_VERSION="2_28"
|
||||
DEVTOOLSET_VERSION="11"
|
||||
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
|
||||
;;
|
||||
manylinux2_28-builder:xpu)
|
||||
|
||||
@ -1 +1 @@
|
||||
3.5.1
|
||||
3.5.0
|
||||
|
||||
@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
|
||||
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
|
||||
RUN rm install_halide.sh common_utils.sh halide.txt
|
||||
|
||||
ARG PALLAS
|
||||
ARG CUDA_VERSION
|
||||
# Install JAX with CUDA support (for Pallas)
|
||||
COPY ./common/install_jax.sh install_jax.sh
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
|
||||
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
|
||||
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
|
||||
|
||||
ARG ONNX
|
||||
# Install ONNX dependencies
|
||||
COPY ./common/install_onnx.sh ./common/common_utils.sh ./
|
||||
|
||||
@ -8,11 +8,9 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
try:
|
||||
from collections.abc import Callable # Python 3.11+
|
||||
from typing import Any, Required, TypedDict
|
||||
from typing import Any, Callable, Required, TypedDict # Python 3.11+
|
||||
except ImportError:
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Callable, TypedDict
|
||||
|
||||
from typing_extensions import Required # Fallback for Python <3.11
|
||||
|
||||
|
||||
@ -6,8 +6,8 @@ set -eou pipefail
|
||||
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
|
||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
|
||||
# https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
|
||||
# post merge of https://github.com/icl-utk-edu/magma/pull/65
|
||||
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
|
||||
|
||||
# Folders for the build
|
||||
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
|
||||
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
|
||||
|
||||
# Fetch magma sources and verify checksum
|
||||
pushd ${PACKAGE_DIR}
|
||||
git clone https://github.com/jeffdaily/magma
|
||||
git clone https://github.com/icl-utk-edu/magma
|
||||
pushd magma
|
||||
git checkout ${MAGMA_VERSION}
|
||||
popd
|
||||
|
||||
@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/compiler/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/umf/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Enable XCCL build
|
||||
export USE_XCCL=1
|
||||
export USE_MPI=0
|
||||
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
|
||||
export USE_KINETO=0
|
||||
export TORCH_XPU_ARCH_LIST=pvc
|
||||
fi
|
||||
|
||||
|
||||
@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Check XPU status before testing
|
||||
timeout 30 xpu-smi discovery || true
|
||||
fi
|
||||
@ -826,11 +824,6 @@ test_inductor_halide() {
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_pallas() {
|
||||
python test/run_test.py --include inductor/test_pallas.py --verbose
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_triton_cpu() {
|
||||
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
|
||||
assert_git_not_dirty
|
||||
@ -1731,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
test_inductor_halide
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
|
||||
test_inductor_pallas
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
|
||||
test_inductor_triton_cpu
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
||||
|
||||
@ -70,7 +70,7 @@ sccache --zero-stats
|
||||
sccache --show-stats
|
||||
|
||||
# Build the wheel
|
||||
python -m build --wheel --no-isolation
|
||||
python -m build --wheel --no-build-isolation
|
||||
if ($LASTEXITCODE -ne 0) { exit 1 }
|
||||
|
||||
# Install the wheel locally
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
name: 🚀 New Feature for Release
|
||||
name: 🚀 Release highlight for proposed Feature
|
||||
description: Submit a Release highlight for proposed Feature
|
||||
labels: ["release-feature-request"]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: New Feature for Release
|
||||
label: Release highlight for proposed Feature
|
||||
description: >
|
||||
Example: “A torch.special module, analogous to SciPy's special module.”
|
||||
- type: input
|
||||
|
||||
12
.github/actions/pytest-cache-download/action.yml
vendored
12
.github/actions/pytest-cache-download/action.yml
vendored
@ -38,9 +38,9 @@ runs:
|
||||
run: |
|
||||
python3 .github/scripts/pytest_cache.py \
|
||||
--download \
|
||||
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
|
||||
--pr_identifier "$GITHUB_REF" \
|
||||
--job_identifier "$JOB_IDENTIFIER" \
|
||||
--temp_dir "$RUNNER_TEMP" \
|
||||
--repo "$REPO" \
|
||||
--bucket "$BUCKET" \
|
||||
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
|
||||
--pr_identifier $GITHUB_REF \
|
||||
--job_identifier $JOB_IDENTIFIER \
|
||||
--temp_dir $RUNNER_TEMP \
|
||||
--repo $REPO \
|
||||
--bucket $BUCKET \
|
||||
|
||||
16
.github/actions/pytest-cache-upload/action.yml
vendored
16
.github/actions/pytest-cache-upload/action.yml
vendored
@ -47,11 +47,11 @@ runs:
|
||||
run: |
|
||||
python3 .github/scripts/pytest_cache.py \
|
||||
--upload \
|
||||
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
|
||||
--pr_identifier "$GITHUB_REF" \
|
||||
--job_identifier "$JOB_IDENTIFIER" \
|
||||
--sha "$SHA" \
|
||||
--test_config "$TEST_CONFIG" \
|
||||
--shard "$SHARD" \
|
||||
--repo "$REPO" \
|
||||
--temp_dir "$RUNNER_TEMP" \
|
||||
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
|
||||
--pr_identifier $GITHUB_REF \
|
||||
--job_identifier $JOB_IDENTIFIER \
|
||||
--sha $SHA \
|
||||
--test_config $TEST_CONFIG \
|
||||
--shard $SHARD \
|
||||
--repo $REPO \
|
||||
--temp_dir $RUNNER_TEMP \
|
||||
|
||||
2
.github/ci_commit_pins/audio.txt
vendored
2
.github/ci_commit_pins/audio.txt
vendored
@ -1 +1 @@
|
||||
ad5816f0eee1c873df1b7d371c69f1f811a89387
|
||||
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
cfbc5c2f1c798991715a6b06bb3ce46478c4487c
|
||||
|
||||
2
.github/ci_commit_pins/xla.txt
vendored
2
.github/ci_commit_pins/xla.txt
vendored
@ -1 +1 @@
|
||||
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
|
||||
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
|
||||
|
||||
125
.github/copilot-instructions.md
vendored
125
.github/copilot-instructions.md
vendored
@ -1,125 +0,0 @@
|
||||
# PyTorch Copilot Instructions
|
||||
|
||||
This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
### Core Components
|
||||
|
||||
- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
|
||||
- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
|
||||
- `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
|
||||
- `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
|
||||
- **torch/** - Python bindings and public API
|
||||
- `torch/csrc/` - C++ Python bindings (hand-written and generated)
|
||||
- `torch/csrc/autograd/` - Reverse-mode automatic differentiation
|
||||
- `torch/csrc/jit/` - TorchScript JIT compiler
|
||||
- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
|
||||
- **tools/** - Build scripts, autograd derivatives, code generation
|
||||
|
||||
### The Code Generation Workflow
|
||||
|
||||
**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
|
||||
1. Declares operator signatures, variants (function/method), and dispatch behavior
|
||||
2. Gets processed by `torchgen/` to generate C++/Python bindings
|
||||
3. Produces headers in `build/aten/src/ATen/` during compilation
|
||||
|
||||
Example entry structure:
|
||||
```yaml
|
||||
- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU: my_op_cpu
|
||||
CUDA: my_op_cuda
|
||||
```
|
||||
|
||||
After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
|
||||
|
||||
## Development Workflows
|
||||
|
||||
### Building from Source
|
||||
|
||||
**Never run `setup.py` directly** - use pip with editable install:
|
||||
```bash
|
||||
python -m pip install --no-build-isolation -v -e .
|
||||
```
|
||||
|
||||
Speed up builds:
|
||||
- `DEBUG=1` - Debug symbols with `-g -O0`
|
||||
- `USE_CUDA=0` - Skip CUDA compilation
|
||||
- `BUILD_TEST=0` - Skip C++ test binaries
|
||||
- Install `ninja` (`pip install ninja`) for faster builds
|
||||
- Use `ccache` for incremental compilation caching
|
||||
|
||||
Rebuild specific targets: `(cd build && ninja <target>)`
|
||||
|
||||
### Testing
|
||||
|
||||
**Critical**: DO NOT run entire test suites. Run specific tests only:
|
||||
```bash
|
||||
python test/test_torch.py TestTorch.test_specific_case
|
||||
```
|
||||
|
||||
**Test structure**: All tests use `torch.testing._internal.common_utils`:
|
||||
```python
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestFeature(TestCase):
|
||||
def test_something(self):
|
||||
# Use self.assertEqual for tensor comparisons
|
||||
pass
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
```
|
||||
|
||||
**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
|
||||
|
||||
### Linting
|
||||
|
||||
Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
|
||||
|
||||
## Project-Specific Conventions
|
||||
|
||||
### Memory and Storage
|
||||
- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
|
||||
- CUDA device info lives in storage objects
|
||||
|
||||
### Python-C++ Integration (`torch/csrc/`)
|
||||
- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
|
||||
- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
|
||||
- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
|
||||
|
||||
### Dispatch System
|
||||
- PyTorch uses operator dispatch to route calls to backend-specific kernels
|
||||
- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
|
||||
- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
|
||||
|
||||
## Git Workflow (AI Agent Specific)
|
||||
|
||||
When preparing PRs from this environment:
|
||||
```bash
|
||||
git stash -u
|
||||
git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch
|
||||
git stash pop
|
||||
# Resolve conflicts if necessary
|
||||
```
|
||||
|
||||
## Common Gotchas
|
||||
|
||||
1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
|
||||
2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
|
||||
3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
|
||||
4. **No internet access** - DO NOT attempt to install dependencies during development
|
||||
|
||||
## Key Files Reference
|
||||
|
||||
- `AGENTS.md` - Instructions specific to AI coding agents
|
||||
- `CONTRIBUTING.md` - Comprehensive human contributor guide
|
||||
- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
|
||||
- `aten/src/ATen/native/README.md` - Operator implementation guide
|
||||
- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
|
||||
|
||||
## Performance Debugging
|
||||
|
||||
Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.
|
||||
22
.github/labeler.yml
vendored
22
.github/labeler.yml
vendored
@ -138,8 +138,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -149,8 +148,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -160,21 +158,7 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
- third_party/fbgemm
|
||||
|
||||
"ciflow/mps":
|
||||
- aten/src/ATen/mps/**
|
||||
- aten/src/ATen/native/mps/**
|
||||
- torch/_inductor/codegen/mps.py
|
||||
- test/test_mps.py
|
||||
- test/inductor/test_mps_basic.py
|
||||
|
||||
"ciflow/h100-symm-mem":
|
||||
- torch/csrc/distributed/c10d/symm_mem/**
|
||||
- torch/distributed/_symmetric_memory/**
|
||||
- test/distributed/**/*mem*
|
||||
- test/distributed/**/*mem*/**
|
||||
|
||||
1
.github/nitpicks.yml
vendored
1
.github/nitpicks.yml
vendored
@ -10,4 +10,3 @@
|
||||
pathFilter:
|
||||
- 'torch/csrc/inductor/aoti_torch/c/*'
|
||||
- 'torch/csrc/inductor/aoti_torch/generated/*'
|
||||
- 'torch/csrc/stable/c/*'
|
||||
|
||||
6
.github/pytorch-probot.yml
vendored
6
.github/pytorch-probot.yml
vendored
@ -2,8 +2,8 @@ tracking_issue: 24422
|
||||
ciflow_tracking_issue: 64124
|
||||
ciflow_push_tags:
|
||||
- ciflow/b200
|
||||
- ciflow/b200-distributed
|
||||
- ciflow/b200-symm-mem
|
||||
- ciflow/b200-distributed
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
@ -22,8 +22,6 @@ ciflow_push_tags:
|
||||
- ciflow/inductor-perf-test-nightly-xpu
|
||||
- ciflow/inductor-periodic
|
||||
- ciflow/inductor-rocm
|
||||
- ciflow/inductor-rocm-mi200
|
||||
- ciflow/inductor-rocm-mi300
|
||||
- ciflow/linux-aarch64
|
||||
- ciflow/mps
|
||||
- ciflow/nightly
|
||||
@ -35,13 +33,11 @@ ciflow_push_tags:
|
||||
- ciflow/quantization-periodic
|
||||
- ciflow/riscv64
|
||||
- ciflow/rocm
|
||||
- ciflow/rocm-mi200
|
||||
- ciflow/rocm-mi300
|
||||
- ciflow/rocm-mi355
|
||||
- ciflow/rocm-navi31
|
||||
- ciflow/s390
|
||||
- ciflow/slow
|
||||
- ciflow/slow-rocm-mi200
|
||||
- ciflow/torchbench
|
||||
- ciflow/triton_binaries
|
||||
- ciflow/trunk
|
||||
|
||||
3
.github/scripts/delete_old_branches.py
vendored
3
.github/scripts/delete_old_branches.py
vendored
@ -1,11 +1,10 @@
|
||||
# Delete old branches
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from github_utils import gh_fetch_json_dict, gh_graphql
|
||||
from gitutils import GitRepo
|
||||
|
||||
3
.github/scripts/filter_test_configs.py
vendored
3
.github/scripts/filter_test_configs.py
vendored
@ -8,11 +8,10 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
from logging import info
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import yaml
|
||||
|
||||
3
.github/scripts/get_workflow_job_id.py
vendored
3
.github/scripts/get_workflow_job_id.py
vendored
@ -11,8 +11,7 @@ import sys
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
|
||||
3
.github/scripts/github_utils.py
vendored
3
.github/scripts/github_utils.py
vendored
@ -3,9 +3,8 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional, Union
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
4
.github/scripts/gitutils.py
vendored
4
.github/scripts/gitutils.py
vendored
@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
3
.github/scripts/lintrunner.sh
vendored
3
.github/scripts/lintrunner.sh
vendored
@ -34,9 +34,6 @@ python3 torch/utils/data/datapipes/gen_pyi.py
|
||||
# Also check generated pyi files
|
||||
find torch -name '*.pyi' -exec git add --force -- "{}" +
|
||||
|
||||
# Print current environment
|
||||
python3 -m pip freeze
|
||||
|
||||
RC=0
|
||||
# Run lintrunner on all files
|
||||
if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
|
||||
|
||||
4
.github/scripts/trymerge.py
vendored
4
.github/scripts/trymerge.py
vendored
@ -17,12 +17,12 @@ import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
|
||||
4
.github/workflows/_rocm-test.yml
vendored
4
.github/workflows/_rocm-test.yml
vendored
@ -97,8 +97,8 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
|
||||
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
|
||||
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
|
||||
if [[ $ngpu -lt 4 ]]; then
|
||||
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
16
.github/workflows/_xpu-test.yml
vendored
16
.github/workflows/_xpu-test.yml
vendored
@ -344,21 +344,5 @@ jobs:
|
||||
if-no-files-found: ignore
|
||||
path: ./**/core.[1-9]*
|
||||
|
||||
- name: Authenticate with AWS
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
|
||||
# The max duration enforced by the server side
|
||||
role-duration-seconds: 18000
|
||||
aws-region: us-east-1
|
||||
|
||||
- name: Upload the benchmark results
|
||||
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
|
||||
with:
|
||||
benchmark-results-dir: test/test-reports
|
||||
dry-run: false
|
||||
schema-version: v3
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Teardown XPU
|
||||
uses: ./.github/actions/teardown-xpu
|
||||
|
||||
1
.github/workflows/b200-distributed.yml
vendored
1
.github/workflows/b200-distributed.yml
vendored
@ -37,6 +37,7 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
|
||||
1
.github/workflows/b200-symm-mem.yml
vendored
1
.github/workflows/b200-symm-mem.yml
vendored
@ -37,6 +37,7 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
|
||||
6
.github/workflows/build-triton-wheel.yml
vendored
6
.github/workflows/build-triton-wheel.yml
vendored
@ -52,15 +52,11 @@ jobs:
|
||||
matrix:
|
||||
py_vers: [ "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t" ]
|
||||
device: ["cuda", "rocm", "xpu", "aarch64"]
|
||||
rocm_version: [""]
|
||||
docker-image: ["pytorch/manylinux2_28-builder:cpu"]
|
||||
include:
|
||||
- device: "rocm"
|
||||
rocm_version: "7.1"
|
||||
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
|
||||
- device: "rocm"
|
||||
rocm_version: "7.0"
|
||||
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
|
||||
- device: "cuda"
|
||||
rocm_version: ""
|
||||
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
|
||||
@ -175,7 +171,7 @@ jobs:
|
||||
|
||||
- uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
|
||||
with:
|
||||
name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }}${{ matrix.rocm_version != '' && format('-{0}', matrix.rocm_version) || '' }}-${{ env.PLATFORM }}
|
||||
name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }}-${{ env.PLATFORM }}
|
||||
if-no-files-found: error
|
||||
path: ${{ runner.temp }}/artifacts/wheelhouse/*
|
||||
|
||||
|
||||
13
.github/workflows/docker-builds.yml
vendored
13
.github/workflows/docker-builds.yml
vendored
@ -56,8 +56,6 @@ jobs:
|
||||
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
|
||||
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
|
||||
pytorch-linux-jammy-py3.10-clang12,
|
||||
pytorch-linux-jammy-py3.11-clang12,
|
||||
pytorch-linux-jammy-py3.12-clang12,
|
||||
pytorch-linux-jammy-py3.13-clang12,
|
||||
pytorch-linux-jammy-py3.14-clang12,
|
||||
pytorch-linux-jammy-rocm-n-py3,
|
||||
@ -67,10 +65,9 @@ jobs:
|
||||
pytorch-linux-jammy-py3.10-gcc11,
|
||||
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3.12-halide,
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
|
||||
pytorch-linux-jammy-xpu-n-1-py3,
|
||||
pytorch-linux-noble-xpu-n-py3,
|
||||
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-jammy-xpu-n-py3,
|
||||
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3-clang18-asan,
|
||||
pytorch-linux-jammy-py3-clang12-onnx,
|
||||
pytorch-linux-jammy-linter,
|
||||
@ -80,11 +77,9 @@ jobs:
|
||||
pytorch-linux-noble-riscv64-py3.12-gcc14
|
||||
]
|
||||
include:
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
|
||||
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
timeout-minutes: 600
|
||||
# Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358
|
||||
|
||||
1
.github/workflows/h100-distributed.yml
vendored
1
.github/workflows/h100-distributed.yml
vendored
@ -37,6 +37,7 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: "linux.c7i.12xlarge"
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '9.0'
|
||||
|
||||
@ -72,7 +72,7 @@ jobs:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },
|
||||
|
||||
@ -83,8 +83,8 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -117,7 +117,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
|
||||
1
.github/workflows/inductor-rocm-mi300.yml
vendored
1
.github/workflows/inductor-rocm-mi300.yml
vendored
@ -7,7 +7,6 @@ on:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
- ciflow/inductor-rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
||||
@ -2,12 +2,12 @@ name: inductor-rocm
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: 0 */3 * * *
|
||||
- cron: 0 * * * *
|
||||
push:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm-mi200/*
|
||||
- ciflow/inductor-rocm/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
26
.github/workflows/inductor-unittest.yml
vendored
26
.github/workflows/inductor-unittest.yml
vendored
@ -81,32 +81,6 @@ jobs:
|
||||
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-build:
|
||||
name: inductor-pallas-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
|
||||
cuda-arch-list: '8.9'
|
||||
runner: linux.8xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-test:
|
||||
name: inductor-pallas-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: inductor-pallas-build
|
||||
with:
|
||||
build-environment: linux-jammy-py3.12-gcc11
|
||||
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-triton-cpu-build:
|
||||
name: inductor-triton-cpu-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
2
.github/workflows/linux-aarch64.yml
vendored
2
.github/workflows/linux-aarch64.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
||||
with:
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
|
||||
8
.github/workflows/nightly.yml
vendored
8
.github/workflows/nightly.yml
vendored
@ -5,11 +5,9 @@ on:
|
||||
- cron: 0 0 * * *
|
||||
push:
|
||||
tags:
|
||||
# NOTE: Doc build pipelines should only get triggered on:
|
||||
# Major or minor release candidates builds
|
||||
- v[0-9]+.[0-9]+.0+-rc[0-9]+
|
||||
# Final RC for major, minor and patch releases
|
||||
- v[0-9]+.[0-9]+.[0-9]+
|
||||
# NOTE: Doc build pipelines should only get triggered on release candidate builds
|
||||
# Release candidate tags look like: v1.11.0-rc1
|
||||
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
|
||||
- ciflow/nightly/*
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
2
.github/workflows/operator_benchmark.yml
vendored
2
.github/workflows/operator_benchmark.yml
vendored
@ -60,7 +60,7 @@ jobs:
|
||||
with:
|
||||
build-environment: linux-jammy-aarch64-py3.10
|
||||
runner: linux.arm64.m7g.4xlarge
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
||||
|
||||
1
.github/workflows/periodic-rocm-mi200.yml
vendored
1
.github/workflows/periodic-rocm-mi200.yml
vendored
@ -11,6 +11,7 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi200/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
1
.github/workflows/periodic-rocm-mi300.yml
vendored
1
.github/workflows/periodic-rocm-mi300.yml
vendored
@ -11,7 +11,6 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi300/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
8
.github/workflows/pull.yml
vendored
8
.github/workflows/pull.yml
vendored
@ -342,16 +342,16 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
||||
# sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
|
||||
|
||||
1
.github/workflows/rocm-mi300.yml
vendored
1
.github/workflows/rocm-mi300.yml
vendored
@ -6,7 +6,6 @@ on:
|
||||
- main
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
- ciflow/rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
|
||||
@ -5,12 +5,11 @@ on:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm-mi200/*
|
||||
- ciflow/rocm/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
- cron: 0 */3 * * *
|
||||
|
||||
- cron: 0 * * * *
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
|
||||
81
.github/workflows/slow-rocm-mi200.yml
vendored
81
.github/workflows/slow-rocm-mi200.yml
vendored
@ -1,81 +0,0 @@
|
||||
# This workflow is dedicated to host slow jobs that are run only periodically because
|
||||
# they are too slow to run in every commit. The list of slow tests can be found in
|
||||
# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
|
||||
name: slow-rocm-mi200
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/slow/*
|
||||
- ciflow/slow-rocm-mi200/*
|
||||
schedule:
|
||||
- cron: 0 */3 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
llm-td:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: before-test
|
||||
uses: ./.github/workflows/llm_td_retrieval.yml
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
target-determination:
|
||||
name: before-test
|
||||
uses: ./.github/workflows/target_determination.yml
|
||||
needs: llm-td
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
30
.github/workflows/slow.yml
vendored
30
.github/workflows/slow.yml
vendored
@ -105,6 +105,36 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3_10-clang18-asan-build:
|
||||
name: linux-jammy-py3.10-clang18-asan
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
3
.github/workflows/test-b200.yml
vendored
3
.github/workflows/test-b200.yml
vendored
@ -52,6 +52,7 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '10.0'
|
||||
@ -72,4 +73,4 @@ jobs:
|
||||
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm100-build.outputs.test-matrix }}
|
||||
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
|
||||
secrets: inherit
|
||||
secrets: inherit
|
||||
1
.github/workflows/test-h100.yml
vendored
1
.github/workflows/test-h100.yml
vendored
@ -41,6 +41,7 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
runner: linux.12xlarge.memory
|
||||
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
|
||||
cuda-arch-list: '9.0'
|
||||
|
||||
5
.github/workflows/upload-test-stats.yml
vendored
5
.github/workflows/upload-test-stats.yml
vendored
@ -11,16 +11,15 @@ on:
|
||||
- inductor
|
||||
- unstable
|
||||
- slow
|
||||
- slow-rocm-mi200
|
||||
- unstable-periodic
|
||||
- inductor-periodic
|
||||
- rocm-mi200
|
||||
- rocm
|
||||
- rocm-mi300
|
||||
- rocm-mi355
|
||||
- inductor-micro-benchmark
|
||||
- inductor-micro-benchmark-x86
|
||||
- inductor-cu124
|
||||
- inductor-rocm-mi200
|
||||
- inductor-rocm
|
||||
- inductor-rocm-mi300
|
||||
- mac-mps
|
||||
- linux-aarch64
|
||||
|
||||
20
.github/workflows/xpu.yml
vendored
20
.github/workflows/xpu.yml
vendored
@ -47,15 +47,15 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -74,17 +74,17 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-noble-xpu-n-py3_10-test:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
linux-jammy-xpu-n-py3_10-test:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: linux-noble-xpu-n-py3_10-build
|
||||
needs: linux-jammy-xpu-n-py3_10-build
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
with:
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
windows-xpu-n-1-build:
|
||||
|
||||
@ -143,8 +143,7 @@ init_command = [
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
|
||||
'numpy==2.3.4 ; python_version >= "3.14"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'expecttest==0.3.0',
|
||||
'pyrefly==0.36.2',
|
||||
'sympy==1.13.3',
|
||||
@ -186,8 +185,6 @@ include_patterns = [
|
||||
'aten/src/ATen/native/nested/cuda/*.h',
|
||||
'aten/src/ATen/native/nested/*.cpp',
|
||||
'aten/src/ATen/native/nested/*.h',
|
||||
'aten/src/ATen/xpu/**/*.h',
|
||||
'aten/src/ATen/xpu/**/*.cpp',
|
||||
'c10/**/*.cpp',
|
||||
'c10/**/*.h',
|
||||
'torch/*.h',
|
||||
@ -1404,7 +1401,7 @@ init_command = [
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'usort==1.0.8.post1',
|
||||
'isort==6.0.1',
|
||||
'ruff==0.14.4', # sync with RUFF
|
||||
'ruff==0.13.1', # sync with RUFF
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
@ -1539,7 +1536,7 @@ init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.14.4', # sync with PYFMT
|
||||
'ruff==0.13.1', # sync with PYFMT
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
||||
106
CMakeLists.txt
106
CMakeLists.txt
@ -234,17 +234,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
|
||||
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
|
||||
option(USE_LSAN "Use Leak Sanitizer" OFF)
|
||||
option(USE_TSAN "Use Thread Sanitizer" OFF)
|
||||
|
||||
# Track whether USE_CUDA was explicitly set by the user (before option() is called)
|
||||
# If USE_CUDA is already defined in cache, it means user explicitly set it
|
||||
if(DEFINED CACHE{USE_CUDA})
|
||||
set(_USE_CUDA_EXPLICITLY_SET TRUE)
|
||||
else()
|
||||
set(_USE_CUDA_EXPLICITLY_SET FALSE)
|
||||
endif()
|
||||
|
||||
option(USE_CUDA "Use CUDA" ON)
|
||||
|
||||
option(USE_XPU "Use XPU" ON)
|
||||
cmake_dependent_option(
|
||||
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
||||
@ -736,44 +726,6 @@ if(NOT DEFINED USE_BLAS)
|
||||
set(USE_BLAS ON)
|
||||
endif()
|
||||
|
||||
# Prioritized Text Linker Optimization
|
||||
if(USE_PRIORITIZED_TEXT_FOR_LD)
|
||||
|
||||
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
|
||||
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python_EXECUTABLE}
|
||||
${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
|
||||
--filein "${LINKER_SCRIPT_FILE_IN}"
|
||||
--fout "${LINKER_SCRIPT_FILE_OUT}"
|
||||
RESULT_VARIABLE _gen_result
|
||||
OUTPUT_VARIABLE _gen_output
|
||||
ERROR_VARIABLE _gen_error
|
||||
)
|
||||
|
||||
if(NOT _gen_result EQUAL 0)
|
||||
message(FATAL_ERROR
|
||||
"Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
|
||||
endif()
|
||||
|
||||
append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
|
||||
append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
|
||||
append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
|
||||
append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
|
||||
|
||||
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
|
||||
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
|
||||
|
||||
else()
|
||||
if(LINUX AND CPU_AARCH64)
|
||||
message(WARNING [[
|
||||
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
]])
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Build libtorch mobile library, which contains ATen/TH ops and native support
|
||||
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
|
||||
if(INTERN_BUILD_MOBILE)
|
||||
@ -1440,6 +1392,9 @@ if(BUILD_JNI)
|
||||
add_subdirectory(android/pytorch_android)
|
||||
endif()
|
||||
|
||||
include(cmake/Summary.cmake)
|
||||
caffe2_print_configuration_summary()
|
||||
|
||||
# Parse custom debug info
|
||||
if(DEFINED USE_CUSTOM_DEBINFO)
|
||||
string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
|
||||
@ -1479,5 +1434,56 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
|
||||
DESTINATION "${CMAKE_INSTALL_BINDIR}")
|
||||
endif()
|
||||
|
||||
include(cmake/Summary.cmake)
|
||||
caffe2_print_configuration_summary()
|
||||
if(USE_PRIORITIZED_TEXT_FOR_LD)
|
||||
add_compile_options(
|
||||
$<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
|
||||
$<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
|
||||
)
|
||||
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
|
||||
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
|
||||
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
|
||||
DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
|
||||
COMMENT "Generating prioritized text linker files"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
|
||||
|
||||
if(BUILD_PYTHON)
|
||||
set(LINKER_OPT_TARGETS torch_python)
|
||||
endif()
|
||||
|
||||
if(NOT BUILD_LIBTORCHLESS)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
|
||||
if(USE_CUDA)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
|
||||
endif()
|
||||
if(USE_XPU)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
|
||||
endif()
|
||||
if(USE_ROCM)
|
||||
list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
foreach(tgt IN LISTS LINKER_OPT_TARGETS)
|
||||
if(TARGET ${tgt})
|
||||
add_dependencies("${tgt}" generate_linker_script)
|
||||
target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
|
||||
set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
|
||||
else()
|
||||
message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
else()
|
||||
if(LINUX AND CPU_AARCH64)
|
||||
message(WARNING [[
|
||||
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
|
||||
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
|
||||
]])
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
||||
/test/inductor/test_flex_attention.py @drisspg
|
||||
/test/inductor/test_flex_decoding.py @drisspg
|
||||
|
||||
# Low Precision & Grouped GEMMs
|
||||
# Low Precision GEMMs
|
||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
|
||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
|
||||
|
||||
@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
|
||||
- [Python Unit Testing](#python-unit-testing)
|
||||
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
|
||||
- [Local linting](#local-linting)
|
||||
- [Running `pyrefly`](#running-pyrefly)
|
||||
- [Running `mypy`](#running-mypy)
|
||||
- [C++ Unit Testing](#c-unit-testing)
|
||||
- [Run Specific CI Jobs](#run-specific-ci-jobs)
|
||||
- [Merging your Change](#merging-your-change)
|
||||
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
|
||||
**Prerequisites**:
|
||||
The following packages should be installed with `pip`:
|
||||
- `expecttest` and `hypothesis` - required to run tests
|
||||
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
|
||||
- `mypy` - recommended for linting
|
||||
- `pytest` - recommended to run tests more selectively
|
||||
Running
|
||||
```
|
||||
@ -350,32 +350,15 @@ make lint
|
||||
|
||||
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
|
||||
|
||||
#### Running `pyrefly`
|
||||
#### Running `mypy`
|
||||
|
||||
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
|
||||
|
||||
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
|
||||
|
||||
**Getting Started with Pyrefly:**
|
||||
|
||||
To run type checking on the PyTorch codebase:
|
||||
```bash
|
||||
pyrefly check
|
||||
```
|
||||
|
||||
For more detailed error information with summaries:
|
||||
```bash
|
||||
pyrefly check --summarize-errors
|
||||
```
|
||||
|
||||
**Learn More:**
|
||||
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
|
||||
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
|
||||
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
|
||||
`mypy` is an optional static type checker for Python. We have multiple `mypy`
|
||||
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
|
||||
|
||||
See [Guide for adding type annotations to
|
||||
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
|
||||
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
|
||||
for more information on how to set up `mypy` and tackle type annotation
|
||||
tasks.
|
||||
|
||||
### C++ Unit Testing
|
||||
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Arm:
|
||||
Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
|
||||
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
|
||||
@ -174,12 +174,6 @@ class TORCH_API Context {
|
||||
static long versionCuDNN() {
|
||||
return detail::getCUDAHooks().versionCuDNN();
|
||||
}
|
||||
static long versionRuntimeCuDNN() {
|
||||
return detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
}
|
||||
static long versionCuDNNFrontend() {
|
||||
return detail::getCUDAHooks().versionCuDNNFrontend();
|
||||
}
|
||||
static bool hasCuSOLVER() {
|
||||
return detail::getCUDAHooks().hasCuSOLVER();
|
||||
}
|
||||
|
||||
@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
||||
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
||||
}
|
||||
|
||||
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
||||
c10::DeviceIndex device_index) {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
||||
}
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include <cuda.h> // For CUDA_VERSION
|
||||
@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
case enum_type: { \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
||||
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
#define AT_DISPATCH_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
||||
@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// The AT_DISPATCH_* family of macros provides the ability to
|
||||
// conveniently generate specializations of a kernel over all of the
|
||||
// dtypes we care about in PyTorch. We call it "dispatch" because
|
||||
@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
// but we're just being safe (and it doesn't hurt.) Note we must
|
||||
// use it to shut up warnings about unused store.
|
||||
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
THO_DISPATCH_SWITCH_TMPL( \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE, \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
__VA_ARGS__)
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const auto& the_type = TYPE; \
|
||||
constexpr const char* at_dispatch_name = NAME; \
|
||||
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
||||
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED( \
|
||||
false, \
|
||||
'"', \
|
||||
at_dispatch_name, \
|
||||
"\" not implemented for '", \
|
||||
toString(_st), \
|
||||
"'"); \
|
||||
} \
|
||||
C10_DIAGNOSTIC_POP() \
|
||||
}()
|
||||
|
||||
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
|
||||
@ -1,8 +1,3 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
|
||||
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
// This is a new implementation of the AT_DISPATCH macro family from
|
||||
@ -79,19 +74,41 @@
|
||||
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
||||
// relied on GPT4 to help me get it right.
|
||||
|
||||
// Public API macros
|
||||
|
||||
// See documentation above
|
||||
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
AT_DISPATCH_SWITCH, \
|
||||
AT_DISPATCH_CASE, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_WRAP(BODY), \
|
||||
__VA_ARGS__)
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
||||
|
||||
// This macro lets you pass an arbitrary expression that may contain internal
|
||||
// commas to another macro without having the commas causing the expression
|
||||
// to be interpreted as being multiple arguments
|
||||
#define AT_WRAP(...) __VA_ARGS__
|
||||
|
||||
#define AT_FLOAT8_TYPES \
|
||||
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
|
||||
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
|
||||
|
||||
#define AT_INTEGRAL_TYPES \
|
||||
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
|
||||
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
|
||||
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
|
||||
#define AT_INTEGRAL_TYPES_V2 \
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
|
||||
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
|
||||
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
|
||||
// NB: not *actually* all types
|
||||
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
||||
#define AT_ALL_TYPES_AND_COMPLEX \
|
||||
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
|
||||
|
||||
// Helper macros
|
||||
|
||||
// Unused helper macros, kept for BC:
|
||||
#define AT_AP_VAR(N, T, ...) \
|
||||
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
||||
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
|
||||
#define AT_CONCAT_AUX(a, b) a##b
|
||||
#define AT_EXPAND(X) X
|
||||
|
||||
// Ensure we never have too many scalar types for the expansion here to
|
||||
// support. To bump this, you must regenerate the macros below.
|
||||
@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
|
||||
|
||||
num_args = 60
|
||||
|
||||
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
|
||||
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
|
||||
|
||||
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
|
||||
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
|
||||
|
||||
for i in range(1, num_args+1):
|
||||
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
||||
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
||||
@ -112,6 +135,8 @@ for i in range(1, num_args+1):
|
||||
// Begin generated code
|
||||
// clang-format off
|
||||
|
||||
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
||||
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
|
||||
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
||||
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
||||
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
||||
|
||||
@ -226,8 +226,8 @@ template <
|
||||
typename B = HostBlock<S>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() {
|
||||
if (active_) {
|
||||
active_ = false;
|
||||
active_ = false;
|
||||
if (pinned_use_background_threads()) {
|
||||
getBackgroundThreadPool()->waitWorkComplete();
|
||||
}
|
||||
}
|
||||
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
|
||||
if (pinned_use_background_threads()) {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
active_ = true;
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
while (active_) {
|
||||
process_events();
|
||||
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
|
||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
|
||||
// Indicates whether the event-processing thread pool is active.
|
||||
// Indicates whether the object is active.
|
||||
// Set to false in the destructor to signal background threads to stop.
|
||||
std::atomic<bool> active_{false};
|
||||
std::atomic<bool> active_{true};
|
||||
protected:
|
||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
||||
};
|
||||
|
||||
@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
|
||||
auto vals = svreinterpret_u16_bf16(values);
|
||||
vals = sveor_u16_x(ptrue, vals, mask);
|
||||
return svreinterpret_bf16_u16(vals);
|
||||
}
|
||||
};
|
||||
Vectorized<BFloat16> round() const;
|
||||
Vectorized<BFloat16> tan() const;
|
||||
Vectorized<BFloat16> tanh() const;
|
||||
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
|
||||
return convert_float_bfloat16(v1, v2); \
|
||||
}
|
||||
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(round)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(round);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
|
||||
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
|
||||
|
||||
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
|
||||
const Vectorized<BFloat16>& other) const {
|
||||
|
||||
@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
uint32_t mask = -1;
|
||||
#endif
|
||||
void * alpha_ptr = α
|
||||
void * beta_ptr = β
|
||||
@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
|
||||
if (fp16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
mask =
|
||||
uint32_t mask =
|
||||
fp16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
|
||||
if (bf16_reduction !=
|
||||
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
|
||||
mask =
|
||||
uint32_t mask =
|
||||
bf16_reduction ==
|
||||
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
|
||||
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
|
||||
@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
|
||||
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
|
||||
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
|
||||
#ifndef USE_ROCM
|
||||
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
|
||||
#else
|
||||
const bool lie_to_cublaslt = false;
|
||||
#endif
|
||||
if (lie_to_cublaslt) {
|
||||
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
|
||||
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
|
||||
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
FakeBdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
FakeCdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
} else {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
}
|
||||
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
Adesc.descriptor(),
|
||||
Bdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
Cdesc.descriptor(),
|
||||
preference.descriptor(),
|
||||
1,
|
||||
&heuristicResult,
|
||||
&returnedResult));
|
||||
if (returnedResult == 0) {
|
||||
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
|
||||
}
|
||||
@ -1597,7 +1572,7 @@ bool gemm_and_bias(
|
||||
}
|
||||
|
||||
using opmath_t = at::opmath_type<Dtype>;
|
||||
opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
|
||||
opmath_t beta_val = 0; // bias is added in epilogue
|
||||
|
||||
cudaDataType_t abType = CUDA_R_32F;
|
||||
cudaDataType_t cType = CUDA_R_32F;
|
||||
@ -1686,22 +1661,15 @@ bool gemm_and_bias(
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif
|
||||
const auto epilogue = [&]() -> cublasLtEpilogue_t {
|
||||
// The cuBLAS documentation indicates that
|
||||
// *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
|
||||
// but we keep it verbose here for clarity.
|
||||
switch (activation) {
|
||||
case GEMMAndBiasActivationEpilogue::RELU:
|
||||
return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
|
||||
case GEMMAndBiasActivationEpilogue::GELU:
|
||||
return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
|
||||
default:
|
||||
return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
|
||||
}
|
||||
}();
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
|
||||
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
|
||||
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
|
||||
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
|
||||
}
|
||||
|
||||
if (bias) {
|
||||
if (bias != nullptr) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
|
||||
}
|
||||
|
||||
|
||||
@ -24,13 +24,7 @@ namespace detail {
|
||||
// radix_sort_pairs doesn't interact with value_t other than to copy
|
||||
// the data, so we can save template instantiations by reinterpreting
|
||||
// it as an opaque type.
|
||||
// We use native integer types for 1/2/4/8-byte values to reduce
|
||||
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
|
||||
template <int N> struct alignas(N) OpaqueType { char data[N]; };
|
||||
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
|
||||
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
|
||||
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
|
||||
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
|
||||
|
||||
template<typename key_t, int value_size>
|
||||
void radix_sort_pairs_impl(
|
||||
|
||||
@ -21,7 +21,6 @@
|
||||
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#include <ATen/cudnn/cudnn-wrapper.h>
|
||||
#include <cudnn_frontend.h>
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionRuntimeCuDNN() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#ifndef USE_STATIC_CUDNN
|
||||
return cudnnGetVersion();
|
||||
#else
|
||||
return CUDNN_VERSION;
|
||||
#endif
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionCuDNNFrontend() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
return CUDNN_FRONTEND_VERSION;
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionMIOpen() const {
|
||||
#if AT_ROCM_ENABLED()
|
||||
return MIOPEN_VERSION_MAJOR * 10000 +
|
||||
|
||||
@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
bool hasCUDART() const override;
|
||||
long versionCUDART() const override;
|
||||
long versionCuDNN() const override;
|
||||
long versionRuntimeCuDNN() const override;
|
||||
long versionCuDNNFrontend() const override;
|
||||
long versionMIOpen() const override;
|
||||
std::string showConfig() const override;
|
||||
double batchnormMinEpsilonCuDNN() const override;
|
||||
|
||||
@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionRuntimeCuDNN() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionCuDNNFrontend() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionMIOpen() const {
|
||||
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
||||
DispatchKey::Negative,
|
||||
DispatchKey::Conjugate,
|
||||
DispatchKey::XLA,
|
||||
DispatchKey::XPU,
|
||||
DispatchKey::HPU,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
|
||||
@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
|
||||
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
|
||||
m_mutex.unlock();
|
||||
auto stream = getDefaultMPSStream();
|
||||
dispatch_sync_with_rethrow(stream->queue(), ^() {
|
||||
dispatch_sync(stream->queue(), ^() {
|
||||
stream->synchronize(SyncType::COMMIT_AND_WAIT);
|
||||
});
|
||||
m_mutex.lock();
|
||||
|
||||
@ -110,9 +110,6 @@ class TORCH_API MPSStream {
|
||||
return _stream;
|
||||
}
|
||||
|
||||
MTLBuffer_t getErrorBuffer();
|
||||
void checkLastError();
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
MTLCommandQueue_t _commandQueue = nil;
|
||||
@ -124,8 +121,6 @@ class TORCH_API MPSStream {
|
||||
dispatch_queue_t _serialQueue = nullptr;
|
||||
// CommitAndContinue is enabled by default
|
||||
bool _enableCommitAndContinue = true;
|
||||
// Buffer that contains last raised error
|
||||
MTLBuffer_t _errorBuffer = nil;
|
||||
|
||||
// use synchronize() to access any of these commit functions outside MPSStream
|
||||
void commit();
|
||||
@ -160,7 +155,4 @@ class TORCH_API MPSStreamImpl {
|
||||
MPSStreamImpl();
|
||||
};
|
||||
|
||||
#ifdef __OBJC__
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
||||
#endif
|
||||
} // namespace at::mps
|
||||
|
||||
@ -3,13 +3,13 @@
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSProfiler.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <c10/metal/error.h>
|
||||
|
||||
@interface MPSGraphExecutionDescriptor ()
|
||||
@property(readwrite, atomic) BOOL enableCommitAndContinue;
|
||||
@end
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStream
|
||||
//-----------------------------------------------------------------
|
||||
@ -30,10 +30,6 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
|
||||
// Choose level which optimizes for GPU
|
||||
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
|
||||
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
|
||||
|
||||
_errorBuffer = [MPSDevice::getInstance()->device() newBufferWithLength:sizeof(c10::metal::ErrorMessages)
|
||||
options:MTLResourceStorageModeShared];
|
||||
std::memset([_errorBuffer contents], 0, 1024);
|
||||
}
|
||||
|
||||
MPSStream::~MPSStream() {
|
||||
@ -42,8 +38,6 @@ MPSStream::~MPSStream() {
|
||||
[_executionDescriptor release];
|
||||
[_compilationDescriptor release];
|
||||
_executionDescriptor = nil;
|
||||
[_errorBuffer release];
|
||||
_errorBuffer = nil;
|
||||
_compilationDescriptor = nil;
|
||||
|
||||
assert(_commandBuffer == nil);
|
||||
@ -110,7 +104,6 @@ void MPSStream::commitAndWait() {
|
||||
[_prevCommandBuffer waitUntilCompleted];
|
||||
[_prevCommandBuffer release];
|
||||
_prevCommandBuffer = nil;
|
||||
checkLastError();
|
||||
}
|
||||
|
||||
if (_commandBuffer) {
|
||||
@ -118,7 +111,6 @@ void MPSStream::commitAndWait() {
|
||||
[_commandBuffer waitUntilCompleted];
|
||||
[_commandBuffer release];
|
||||
_commandBuffer = nil;
|
||||
checkLastError();
|
||||
}
|
||||
}
|
||||
|
||||
@ -161,7 +153,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
|
||||
if (length == 0) {
|
||||
return;
|
||||
}
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
endKernelCoalescing();
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -191,7 +183,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
|
||||
size_t dstOffset,
|
||||
uint64_t profileId,
|
||||
SyncType syncType) {
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
@autoreleasepool {
|
||||
endKernelCoalescing();
|
||||
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
|
||||
@ -244,7 +236,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
|
||||
auto& profiler = getMPSProfiler();
|
||||
const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
|
||||
|
||||
dispatch_sync_with_rethrow(_serialQueue, ^() {
|
||||
dispatch_sync(_serialQueue, ^() {
|
||||
endKernelCoalescing();
|
||||
if (isGraphProfilingEnabled) {
|
||||
// this function call is only relevant for interval-based Signposts
|
||||
@ -274,24 +266,6 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
|
||||
});
|
||||
}
|
||||
|
||||
id<MTLBuffer> MPSStream::getErrorBuffer() {
|
||||
return _errorBuffer;
|
||||
}
|
||||
|
||||
void MPSStream::checkLastError() {
|
||||
auto msgs = reinterpret_cast<c10::metal::ErrorMessages*>([_errorBuffer contents]);
|
||||
const auto& msg = msgs->msg[0];
|
||||
if (!msgs) {
|
||||
return;
|
||||
}
|
||||
unsigned int count = 0;
|
||||
std::swap(count, msgs->count);
|
||||
if (!count) {
|
||||
return;
|
||||
}
|
||||
throw c10::AcceleratorError({msg.func, msg.file, msg.line}, 1, msg.message);
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
// MPSStreamImpl
|
||||
//-----------------------------------------------------------------
|
||||
@ -315,19 +289,4 @@ MPSStream* getDefaultMPSStream() {
|
||||
return MPSStreamImpl::getInstance();
|
||||
}
|
||||
|
||||
// Helper methods
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
|
||||
__block std::optional<std::exception_ptr> block_exception;
|
||||
dispatch_sync(queue, ^() {
|
||||
try {
|
||||
block();
|
||||
} catch (...) {
|
||||
block_exception = std::current_exception();
|
||||
}
|
||||
});
|
||||
if (block_exception) {
|
||||
std::rethrow_exception(*block_exception);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::mps
|
||||
|
||||
@ -1009,25 +1009,12 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
|
||||
}
|
||||
}
|
||||
|
||||
static Tensor send_to_meta(const Tensor& self, const Device& device) {
|
||||
Tensor out_meta;
|
||||
if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
|
||||
out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
|
||||
out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
|
||||
} else {
|
||||
out_meta = self.to(device);
|
||||
}
|
||||
return out_meta;
|
||||
}
|
||||
|
||||
Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
|
||||
auto out_device = correct_out_device(self, other);
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
|
||||
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
|
||||
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
|
||||
}
|
||||
|
||||
@ -1036,9 +1023,7 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
|
||||
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
|
||||
|
||||
if (self._is_zerotensor()) {
|
||||
if (other._is_zerotensor()) {
|
||||
@ -1067,9 +1052,8 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
|
||||
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
|
||||
auto device_ = Device(DeviceType::Meta);
|
||||
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
|
||||
auto self_meta = send_to_meta(self, device_);
|
||||
auto other_meta = send_to_meta(other, device_);
|
||||
auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
|
||||
auto meta_out = at::_ops::add_Tensor::redispatch(
|
||||
meta_dks, self.to(device_), other.to(device_), alpha);
|
||||
|
||||
auto get_out_like = [&] (const Tensor& tensor)
|
||||
{
|
||||
|
||||
@ -409,7 +409,7 @@ struct ConvParams {
|
||||
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
|
||||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
@ -453,7 +453,7 @@ struct ConvParams {
|
||||
}
|
||||
// native kernel doesn't support 64-bit non-splittable case
|
||||
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
|
||||
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
|
||||
if (cudnn_version < 0 || cudnn_version > 91000) {
|
||||
|
||||
@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
|
||||
// `_flatten_nd_linear` flattens all but the last dimension of the input tensor
|
||||
// before passing it to linear operation
|
||||
static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
|
||||
const auto input_sizes = input.sym_sizes();
|
||||
|
||||
const auto result_flattened = [&]() -> Tensor {
|
||||
const auto input_ncols = input_sizes.back();
|
||||
const auto input_flattened_nrows = [&]() -> c10::SymInt {
|
||||
// can't use -1 in reshape because it errors when a dimension is 0
|
||||
auto flattened_nrows = c10::SymInt{1};
|
||||
for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
|
||||
flattened_nrows *= size;
|
||||
}
|
||||
return flattened_nrows;
|
||||
}();
|
||||
|
||||
const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
|
||||
if (weight.layout() == c10::kStrided) {
|
||||
return at::addmm(bias, input_flattened, weight.t());
|
||||
} else {
|
||||
// weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
|
||||
// so we transpose the problem.
|
||||
// NOTE: at::matmul handles (dense @ sparse) similarly.
|
||||
const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
|
||||
return at::addmm(bias_t, weight, input_flattened.t()).t();
|
||||
const auto input_sizes = input.sym_sizes();
|
||||
// can't use -1 in reshape because it errors when a dimension is 0
|
||||
c10::SymInt flattened_dim = 1;
|
||||
for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
|
||||
flattened_dim = flattened_dim * input_sizes[i];
|
||||
}
|
||||
}();
|
||||
|
||||
// Unflatten flattened row dims
|
||||
auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
|
||||
result_sizes.back() = result_flattened.sym_size(1);
|
||||
return result_flattened.view_symint(result_sizes);
|
||||
auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
|
||||
const auto result = at::addmm(bias, inp_reshape, weight.t());
|
||||
auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
|
||||
c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
|
||||
sizes_vec.push_back(result.sym_size(1));
|
||||
return result.view_symint(sizes_vec);
|
||||
}
|
||||
|
||||
|
||||
@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
|
||||
// Fused op is marginally faster.
|
||||
return at::addmm(*bias, input, weight.t());
|
||||
}
|
||||
|
||||
const auto is_bias_likely_fusable = (
|
||||
bias->defined() &&
|
||||
// cuBLASLt: will fuse in the epilogue without copies
|
||||
// when input/weight/bias are all strided.
|
||||
// When weight is not strided, bias will not be fused,
|
||||
// but we can still dispatch here to avoid at::matmul
|
||||
// path which will probably use a very similar
|
||||
// flattening optimization.
|
||||
((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false())
|
||||
);
|
||||
if (is_bias_likely_fusable && !input.is_xla()) {
|
||||
// Also hit the fused path for contiguous nD input, if not using xla
|
||||
if (bias->defined() && !input.is_xla()) {
|
||||
// Also hit the fused path for contiguous 3D input, if not using xla
|
||||
// backend. Reshaping/flattening has some performance implications on xla.
|
||||
if (input.is_contiguous_or_false()) {
|
||||
bool is_contiguous = input.is_contiguous_or_false();
|
||||
if (is_contiguous && input_dim == 3) {
|
||||
return _flatten_nd_linear(input, weight, *bias);
|
||||
} else if (parseLinearFlatten3d()) {
|
||||
} else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
|
||||
return _flatten_nd_linear(input, weight, *bias);
|
||||
} else if (parseLinearFlatten3d() && input_dim == 3) {
|
||||
// If user forces flattening via env var
|
||||
const Tensor input_cont = input.contiguous();
|
||||
return _flatten_nd_linear(input_cont, weight, *bias);
|
||||
|
||||
@ -23,7 +23,6 @@
|
||||
#include <ATen/ops/_aminmax_native.h>
|
||||
#include <ATen/ops/_assert_async_native.h>
|
||||
#include <ATen/ops/_assert_scalar_native.h>
|
||||
#include <ATen/ops/_async_error_native.h>
|
||||
#include <ATen/ops/_functional_assert_async_native.h>
|
||||
#include <ATen/ops/_functional_assert_scalar_native.h>
|
||||
#include <ATen/ops/_make_per_tensor_quantized_tensor.h>
|
||||
@ -480,14 +479,6 @@ Tensor isfinite(const Tensor& self) {
|
||||
});
|
||||
}
|
||||
|
||||
void _async_error(std::string_view msg) {
|
||||
TORCH_CHECK(0, msg);
|
||||
}
|
||||
|
||||
void _async_error_meta(std::string_view msg) {
|
||||
// Do NOT error, it's an async error!
|
||||
}
|
||||
|
||||
void _assert_async_cpu(const Tensor& self) {
|
||||
TORCH_CHECK(
|
||||
native::is_nonzero(self),
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
@ -1711,37 +1710,11 @@ Tensor narrow_symint(
|
||||
"], but got ",
|
||||
start,
|
||||
")")
|
||||
|
||||
auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
|
||||
auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
|
||||
|
||||
if (cond1 || cond2) {
|
||||
if (cond1) {
|
||||
start = start + cur_size;
|
||||
}
|
||||
|
||||
TORCH_SYM_CHECK(
|
||||
start.sym_le(cur_size - length),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
length,
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
if (start < 0) {
|
||||
start = start + cur_size;
|
||||
}
|
||||
|
||||
// Unbacked start handling!
|
||||
|
||||
// Bounds check without converting start:
|
||||
// - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
|
||||
// length <= 0
|
||||
// - If start >= 0: need start + length <= cur_size
|
||||
auto end = start + length;
|
||||
TORCH_SYM_CHECK(
|
||||
(start.sym_lt(0).sym_and((end).sym_le(0)))
|
||||
.sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
|
||||
start.sym_le(cur_size - length),
|
||||
"start (",
|
||||
start,
|
||||
") + length (",
|
||||
@ -1749,28 +1722,7 @@ Tensor narrow_symint(
|
||||
") exceeds dimension size (",
|
||||
cur_size,
|
||||
").");
|
||||
|
||||
if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
|
||||
return at::slice_symint(self, dim, start, end, 1);
|
||||
} else {
|
||||
// Cannot statically determine the condition due to unbacked.
|
||||
// This is an interesting situation; when start is negative and
|
||||
// start + length == 0, slice and narrow do different things.
|
||||
// i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
|
||||
// pass curr_size instead of 0. Otherwise, they would do the same thing.
|
||||
// This says at runtime: if start < 0 and end == 0, then pass curr_size
|
||||
// instead of 0.
|
||||
|
||||
auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
|
||||
auto result =
|
||||
at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
|
||||
|
||||
// Ensure slice allocated unbacked size is specialized to length.
|
||||
SymInt new_size = result.sym_size(dim);
|
||||
TORCH_SYM_CHECK(new_size.sym_eq(length), "")
|
||||
|
||||
return result;
|
||||
}
|
||||
return at::slice_symint(self, dim, start, start + length, 1);
|
||||
}
|
||||
|
||||
// This overload exists purely for XLA, because they wanted to pass in
|
||||
@ -1784,8 +1736,8 @@ Tensor narrow_tensor_symint(
|
||||
start.dim() == 0 &&
|
||||
isIntegralType(start.scalar_type(), /*includeBool=*/false),
|
||||
"start must be an 0-dim integral Tensor.");
|
||||
c10::SymInt st = start.item().toSymInt();
|
||||
return at::narrow_symint(self, dim, std::move(st), std::move(length));
|
||||
int64_t st = start.item<int64_t>();
|
||||
return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
|
||||
}
|
||||
|
||||
std::
|
||||
|
||||
@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
|
||||
, empty(size <= 0) {}
|
||||
|
||||
inline Vec unnormalize(const Vec &in) const {
|
||||
return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
|
||||
return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
|
||||
}
|
||||
|
||||
inline Vec clip_coordinates(const Vec &in) const {
|
||||
@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
|
||||
|
||||
// constant used in cubic convolution
|
||||
// could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
|
||||
const Vec A = Vec(static_cast<scalar_t>(-0.75));
|
||||
const Vec A = Vec(-0.75);
|
||||
|
||||
ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
|
||||
: inp_H(input.size(2))
|
||||
|
||||
@ -147,24 +147,14 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
/*
|
||||
* Check whether for the given input we want to enable the Lt interface
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(
|
||||
Tensor& result,
|
||||
const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Activation activation
|
||||
) {
|
||||
#ifdef USE_ROCM
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
// Implies 2D bias which we currently not send through Lt.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
// so, this condition can be ralexed in cases when a col-major
|
||||
// copy of result is needed.
|
||||
if (self.is_same(result) || self.dim() == 2) {
|
||||
if (result.is_same(self)) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
@ -179,33 +169,13 @@ static bool isInputCompliesAddmmCudaLt(
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// NOTE: row-major result is important when bias is 1D.
|
||||
// This is because Lt broadcasts 1D bias over the columns
|
||||
// while the aten::addmm API broadcasts it over the rows,
|
||||
// and this is in conjuction with the data preparation
|
||||
// procedure that does not transpose arguments with
|
||||
// col-major result. For col-major result we need
|
||||
// to explicitly transpose the problem so that bias is
|
||||
// correctly applied.
|
||||
// TODO: enable col-major result if needed.
|
||||
// TODO: no need to check result's layout when
|
||||
// !result.is_same(self) and self.dim() == 2, because
|
||||
// self needs to be copied into result and the bias ptr
|
||||
// will be ignored.
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
// Conditions for bias to be fusable
|
||||
&& (
|
||||
( // Conditions for bias to be fusable -- implies direct Lt path without copies.
|
||||
self.is_contiguous() &&
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
)
|
||||
|| ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
|
||||
// and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
|
||||
// See also TODO from above.
|
||||
activation != Activation::None && // Lt is faster when activation is fused
|
||||
(self.dim() == 2 && at::is_expandable_to(self.sizes(), {mat1_sizes[0], mat2_sizes[1]}))
|
||||
)
|
||||
self.is_contiguous() &&
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
)
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
@ -300,16 +270,7 @@ bool launchGemmAndBiasCublasLt(
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
// We apply bias in the epilogue only when it is 1D,
|
||||
// or when it can be squeezed to 1D.
|
||||
// self_ptr == nullptr implies ignore bias epilogue
|
||||
// and use standard gemm-like API.
|
||||
const auto* self_ptr = [&]() -> auto {
|
||||
if (self.dim() == 1 || self.squeeze().dim() == 1) {
|
||||
return self.const_data_ptr<scalar_t>();
|
||||
}
|
||||
return static_cast<const scalar_t*>(nullptr);
|
||||
}();
|
||||
const auto* self_ptr = self.const_data_ptr<scalar_t>();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
@ -395,7 +356,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
@ -405,20 +366,19 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
// We use bias ptr in the Lt path only when bias is 1D
|
||||
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (!use_bias_ptr_lt) {
|
||||
// We do expand self even before
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We do not copy bias only when we need the bias ptr
|
||||
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
}
|
||||
|
||||
@ -884,69 +884,6 @@ struct type_specialized_kernel_launcher {
|
||||
}
|
||||
};
|
||||
|
||||
template <int arg_index>
|
||||
struct type_specialized_broadcast_kernel_launcher {
|
||||
template <
|
||||
typename func_t,
|
||||
typename array_t,
|
||||
typename dtypes_t,
|
||||
typename calc_t>
|
||||
static void apply(
|
||||
int64_t numel,
|
||||
func_t f,
|
||||
array_t data,
|
||||
dtypes_t dtypes,
|
||||
calc_t offset_calc) {
|
||||
using traits = function_traits<func_t>;
|
||||
using ret_t = typename traits::result_type;
|
||||
using arg0_t = typename traits::template arg<0>::type;
|
||||
using arg1_t = typename traits::template arg<1>::type;
|
||||
if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
|
||||
dtypes[1] == rt_binary_specializations[arg_index][1] &&
|
||||
dtypes[2] == rt_binary_specializations[arg_index][2]) {
|
||||
using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
|
||||
using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
|
||||
using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
auto offsets0 = offset_calc.get(idx);
|
||||
auto offsets1 = offset_calc.get(idx + grp_sz);
|
||||
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
|
||||
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
|
||||
void* out0 = data[0] + offsets0[0];
|
||||
void* out1 = data[0] + offsets1[0];
|
||||
void* out2 = data[0] + offsets2[0];
|
||||
void* out3 = data[0] + offsets3[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
|
||||
ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
|
||||
auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
|
||||
ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
|
||||
auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
|
||||
auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
|
||||
ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
|
||||
auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
|
||||
auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
|
||||
ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
|
||||
*(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
|
||||
*(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
|
||||
*(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
|
||||
*(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
|
||||
} else {
|
||||
auto offsets = offset_calc.get(idx);
|
||||
void* out = data[0] + offsets[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
|
||||
ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
*(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
@ -1065,32 +1002,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
|
||||
}
|
||||
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
|
||||
#ifdef USE_ROCM
|
||||
if (check_binary_rt_types_for_specialization(iter)) {
|
||||
// constexpr to reduce the amount of kernels generated for
|
||||
// broadcast elementwise with mexed dtypes and limit which functors are actually
|
||||
// applied to the load and store at compile time.
|
||||
using func_tuple = typename traits::ArgsTuple;
|
||||
if constexpr (
|
||||
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
|
||||
check_binary_functor_types_for_specialization<
|
||||
func_tuple,
|
||||
float,
|
||||
float,
|
||||
traits::arity,
|
||||
/*arg_num=*/0>::check()) {
|
||||
memory::detail::static_unroll<
|
||||
type_specialized_broadcast_kernel_launcher,
|
||||
rt_binary_specializations.size()>::with_args(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
dtypes,
|
||||
offset_calc
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
|
||||
@ -22,9 +22,6 @@
|
||||
#include <ATen/native/cuda/RowwiseScaledMM.h>
|
||||
#include <ATen/native/cuda/ScaledGroupMM.h>
|
||||
#include <ATen/native/cuda/GroupMM.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <ATen/native/hip/ck_group_gemm.h>
|
||||
#endif
|
||||
#include <ATen/ceil_div.h>
|
||||
|
||||
#ifdef USE_FBGEMM_GENAI
|
||||
@ -669,19 +666,12 @@ std::optional<c10::ScalarType> out_dtype) {
|
||||
// _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
|
||||
// the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
|
||||
bool use_fast_path = false;
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
|
||||
use_fast_path = true;
|
||||
}
|
||||
#endif
|
||||
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
if (use_fast_path) {
|
||||
// fast path, no d2h sync needed
|
||||
#ifndef USE_ROCM
|
||||
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
|
||||
#else
|
||||
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
|
||||
#endif
|
||||
} else {
|
||||
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <array>
|
||||
#include <type_traits>
|
||||
#include <ATen/core/TensorBase.h>
|
||||
#include <ATen/ceil_div.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Dispatch_v2.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
@ -73,6 +74,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
|
||||
|
||||
char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
|
||||
char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
|
||||
|
||||
if (is_gather_like && num_indices==1) {
|
||||
const size_t element_size = iter.element_size(0);
|
||||
constexpr size_t alignment = 16;
|
||||
@ -82,9 +84,16 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
|
||||
auto ind_dim_size = index_size[0];
|
||||
auto inp_stride_bytes = index_stride[0];
|
||||
auto out_stride_bytes = iter.strides(0)[1];
|
||||
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
|
||||
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
|
||||
return;
|
||||
// avoid grid overflow in the fast kernel
|
||||
const int64_t vec_chunks = ceil_div(slice_size, alignment);
|
||||
const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd);
|
||||
const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
// if it's an eligible grid we use the fast path, otherwise default to slower path
|
||||
if (blocks_per_slice_upper <= max_grid_y) {
|
||||
at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
|
||||
slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -13,12 +13,11 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
|
||||
if (allow_neg_indices) {
|
||||
ind = (ind < 0) ? ind + ind_dim_size : ind;
|
||||
}
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
|
||||
// off is guaranteed to be within int32 limits
|
||||
for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
|
||||
}
|
||||
CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
|
||||
int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
|
||||
if (off >= slice_size) return;
|
||||
auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
|
||||
at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
|
||||
}
|
||||
|
||||
|
||||
@ -31,9 +30,7 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
|
||||
auto num_threads = at::round_up(
|
||||
at::ceil_div(slice_size_in_bytes, Alignment),
|
||||
static_cast<int64_t>(C10_WARP_SIZE));
|
||||
uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
|
||||
dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
|
||||
dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
|
||||
auto block = std::min(max_num_threads, num_threads);
|
||||
vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
|
||||
ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <optional>
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::Tensor& out);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
@ -1,462 +0,0 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
#include <ATen/hip/HIPContext.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
#include <c10/hip/HIPStream.h>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck/ck.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/utility/tuple.hpp>
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
namespace at {
|
||||
namespace hip {
|
||||
namespace detail {
|
||||
|
||||
namespace CkTypes {
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
|
||||
ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
|
||||
DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
|
||||
CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
|
||||
3, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1,32,1,8>, 4
|
||||
>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename DataType>
|
||||
void launch_grouped_bgemm_ck_impl_dispatch(
|
||||
const at::Tensor& mat_a,
|
||||
const at::Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
at::Tensor& out)
|
||||
{
|
||||
using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
|
||||
using PassThrough = CkTypes::PassThrough;
|
||||
|
||||
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
|
||||
std::vector<const void*> p_a_ptrs, p_b_ptrs;
|
||||
std::vector<void*> p_e_ptrs;
|
||||
// Note: d_ptrs will be resized after we populate the other vectors
|
||||
|
||||
const int mat_a_dim = mat_a.dim();
|
||||
const int mat_b_dim = mat_b.dim();
|
||||
|
||||
const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
|
||||
const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
|
||||
char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
|
||||
const size_t a_element_size = mat_a.element_size();
|
||||
const size_t b_element_size = mat_b.element_size();
|
||||
const size_t out_element_size = out.element_size();
|
||||
|
||||
// for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
|
||||
if (mat_a_dim == 2 && mat_b_dim == 2) {
|
||||
// 2D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
const int M = mat_a.size(0); // number of rows in A
|
||||
const int N = mat_b.size(1); // number of columns in B
|
||||
const int K = mat_a.size(1); // columns in A == rows in B
|
||||
// for 2d*2d input, output is 3d.
|
||||
// for each group, A columns (K) are sliced. M and N dimensions are not sliced.
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_k = (i == 0) ? 0 : offs_accessor[i-1];
|
||||
int end_k = offs_accessor[i];
|
||||
int k = end_k - start_k;
|
||||
|
||||
//K dimension are sliced, hence select stride(1) always.
|
||||
//K dimension is always dimension 1, regardless of memory layout (row/column major)
|
||||
const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
|
||||
// Leading dimension = distance between rows = stride(0)
|
||||
ldb = mat_b.stride(0);
|
||||
} else {
|
||||
// Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
|
||||
group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
|
||||
// Leading dimension = distance between columns = stride(1)
|
||||
ldb = mat_b.stride(1);
|
||||
}
|
||||
|
||||
// Calculate output pointer for group i in 3D tensor [num_groups, M, N]
|
||||
// stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
int lda, ldc;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [M,K]: leading dimension = distance between rows = stride(0)
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// Column-major A [M,K]: leading dimension = distance between columns = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
}
|
||||
// Output is always row-major in 3D tensor [num_groups, M, N]
|
||||
// Leading dimension for each group's [M,N] slice = stride(1) = N
|
||||
ldc = out.stride(1);
|
||||
size_t output_group_bytes = M * N * out_element_size;
|
||||
void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(k),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 2 && mat_b_dim == 3) {
|
||||
// 2D*3D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
|
||||
// 2d*3d input, output is 2d.
|
||||
// A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
|
||||
// Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
|
||||
const int K = mat_a.size(1); // columns in A
|
||||
// For 2D-3D case: The output determines N (result width)
|
||||
const int N = out.size(1); // N is the width of the output tensor
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_m = offs_accessor[i];
|
||||
int m = end_m - start_m;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (m <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A rows for group i: skip start_m rows
|
||||
const void* group_a_ptr;
|
||||
int lda;
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
lda = mat_a.stride(0); // distance between rows
|
||||
} else {
|
||||
// Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
|
||||
group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Detect stride pattern for A tensor to determine appropriate lda calculation
|
||||
bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));
|
||||
|
||||
if (a_is_strided_tensor) {
|
||||
// For strided A tensors: stride(0) gives the actual leading dimension
|
||||
lda = mat_a.stride(0);
|
||||
} else {
|
||||
// For non-strided A tensors: use the M dimension (total rows)
|
||||
lda = mat_a.size(0); // Total M dimension for column-major layout
|
||||
}
|
||||
}
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
int ldb;
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
|
||||
ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
|
||||
} else {
|
||||
// Detect stride pattern to determine appropriate ldb calculation
|
||||
bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));
|
||||
|
||||
if (is_strided_tensor) {
|
||||
// For strided tensors: stride(2) gives the actual leading dimension
|
||||
ldb = mat_b.stride(2);
|
||||
} else {
|
||||
// For non-strided tensors: use the N dimension
|
||||
ldb = mat_b.size(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
|
||||
void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
|
||||
int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(m),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 3) {
|
||||
// 3d*3d input, output is 3d - batched matrix multiplication
|
||||
// A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
|
||||
// Each batch is processed as a separate GEMM operation
|
||||
const int batch_size = mat_a.size(0);
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)
|
||||
|
||||
// Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
|
||||
int N;
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
N = mat_b.size(2);
|
||||
} else if (mat_b.size(2) == K) {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
N = mat_b.size(1);
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
|
||||
batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B batch for group i: B[i, :, :]
|
||||
const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
|
||||
|
||||
// Select output batch for group i: Output[i, :, :]
|
||||
void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
|
||||
|
||||
int lda, ldb, ldc;
|
||||
|
||||
if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
} else {
|
||||
// Column-major A: leading dimension = distance between columns = stride(2)
|
||||
lda = mat_a.stride(2);
|
||||
}
|
||||
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B: leading dimension = distance between rows
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(1); // stride between K rows
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
|
||||
ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
|
||||
}
|
||||
} else {
|
||||
// Column-major B: leading dimension = distance between columns
|
||||
if (mat_b.size(1) == K) {
|
||||
// B is [batch, k, n] - normal layout
|
||||
ldb = mat_b.stride(2); // stride between N columns
|
||||
} else {
|
||||
// B is [batch, n, k] - transposed layout
|
||||
ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
|
||||
}
|
||||
}
|
||||
|
||||
// Output is typically row-major: leading dimension = distance between rows = stride(1)
|
||||
ldc = out.stride(1);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(N),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else if (mat_a_dim == 3 && mat_b_dim == 2) {
|
||||
// 3D*2D case requires offset tensor
|
||||
auto offs_accessor = offs->accessor<int, 1>();
|
||||
int num_groups = offs_accessor.size(0);
|
||||
// 3d*2d input, output is 3d.
|
||||
// A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
|
||||
// Offset divides N dimension of B, each group gets different slice of B and different batch of A
|
||||
const int batch_size = mat_a.size(0); // n_groups
|
||||
const int M = mat_a.size(1); // rows in each A matrix
|
||||
const int K = mat_a.size(2); // columns in A
|
||||
|
||||
// For row-major A and B case: B should be [K, total_N]
|
||||
const int total_N = mat_b.size(1); // B is [K, total_N] for row-major
|
||||
|
||||
for (int i = 0; i < num_groups; ++i) {
|
||||
int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
|
||||
int end_n = offs_accessor[i];
|
||||
int n = end_n - start_n;
|
||||
|
||||
// Skip zero-sized groups but continue processing subsequent groups
|
||||
if (n <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Select A batch for group i: A[i, :, :]
|
||||
const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;
|
||||
|
||||
// Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
|
||||
const void* group_b_ptr;
|
||||
int ldb;
|
||||
|
||||
// Check if B is row-major or column-major
|
||||
if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
|
||||
// Row-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(0); // distance between rows (should be total_N)
|
||||
} else {
|
||||
// Column-major B [K, total_N]: slice columns [start_n:end_n]
|
||||
group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
|
||||
ldb = mat_b.stride(1); // distance between columns (should be K)
|
||||
}
|
||||
|
||||
// Select output slice for group i: Output[:, start_n:end_n]
|
||||
void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;
|
||||
|
||||
int lda, ldc;
|
||||
|
||||
// Row-major A: leading dimension = distance between rows = stride(1)
|
||||
lda = mat_a.stride(1);
|
||||
// Output is row-major: leading dimension = distance between rows = stride(0)
|
||||
ldc = out.stride(0);
|
||||
|
||||
gemm_descs.push_back({
|
||||
static_cast<ck::index_t>(M),
|
||||
static_cast<ck::index_t>(n),
|
||||
static_cast<ck::index_t>(K),
|
||||
static_cast<ck::index_t>(lda),
|
||||
static_cast<ck::index_t>(ldb),
|
||||
static_cast<ck::index_t>(ldc),
|
||||
{} // --> stride_Ds_
|
||||
});
|
||||
p_a_ptrs.push_back(group_a_ptr);
|
||||
p_b_ptrs.push_back(group_b_ptr);
|
||||
p_e_ptrs.push_back(group_e_ptr);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
|
||||
}
|
||||
|
||||
TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");
|
||||
|
||||
// Initialize d_ptrs with the correct size
|
||||
std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());
|
||||
|
||||
static DeviceOp gemm_instance;
|
||||
auto argument = gemm_instance.MakeArgument(
|
||||
p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
|
||||
gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
|
||||
);
|
||||
TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
|
||||
"CK Group GEMM: argument unsupported (shape/strides/type config)");
|
||||
size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
|
||||
size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);
|
||||
|
||||
void* gemm_arg_buf = nullptr;
|
||||
void* ws_buf = nullptr;
|
||||
|
||||
hipMalloc(&gemm_arg_buf, arg_buf_size);
|
||||
hipMalloc(&ws_buf, ws_size);
|
||||
|
||||
gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
|
||||
gemm_instance.SetWorkSpacePointer(&argument, ws_buf);
|
||||
|
||||
auto invoker = gemm_instance.MakeInvoker();
|
||||
hipStream_t stream = c10::hip::getCurrentHIPStream();
|
||||
invoker.Run(argument, {stream});
|
||||
hipFree(gemm_arg_buf);
|
||||
hipFree(ws_buf);
|
||||
}
|
||||
|
||||
void group_gemm_ck(
|
||||
const at::Tensor& input_a,
|
||||
const at::Tensor& input_b_colmajor,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& /*bias*/,
|
||||
at::Tensor& out)
|
||||
{
|
||||
// Detect if input_a is row-major based on stride pattern
|
||||
bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
|
||||
bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
|
||||
// Ensure tensor A is row-major and contiguous if not already
|
||||
at::Tensor mat_a = input_a;
|
||||
if (!a_row_major) {
|
||||
// If A is not row-major, make it contiguous (row-major)
|
||||
mat_a = input_a.contiguous();
|
||||
}
|
||||
// Force tensor B to be column-major using double transpose trick
|
||||
// This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
|
||||
at::Tensor mat_b = input_b_colmajor;
|
||||
if (!b_col_major) {
|
||||
mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
|
||||
}
|
||||
|
||||
// For 3D tensors, check the last dimension stride for row-major detection
|
||||
a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
|
||||
bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);
|
||||
|
||||
if (mat_a.dtype() == at::kBFloat16) {
|
||||
// bf16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kHalf) {
|
||||
// fp16 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else if (mat_a.dtype() == at::kFloat) {
|
||||
// fp32 path
|
||||
if (a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (a_row_major && !b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else if (!a_row_major && b_row_major) {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
} else {
|
||||
launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace hip
|
||||
} // namespace at
|
||||
@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
|
||||
// supported in conv.
|
||||
mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
|
||||
if (groups > 1 && weight_zero_points.numel() > 1)
|
||||
mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
|
||||
mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
bool src_need_zp = (act_zero_point != 0);
|
||||
|
||||
@ -40,6 +40,8 @@ using namespace at::mps;
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
|
||||
|
||||
struct MPSScalar {
|
||||
id<MTLBuffer> getMTLBuffer() const {
|
||||
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
|
||||
|
||||
@ -53,6 +53,21 @@
|
||||
@end
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
|
||||
__block std::optional<std::exception_ptr> block_exception;
|
||||
dispatch_sync(queue, ^() {
|
||||
try {
|
||||
block();
|
||||
} catch (...) {
|
||||
block_exception = std::current_exception();
|
||||
}
|
||||
});
|
||||
if (block_exception) {
|
||||
std::rethrow_exception(*block_exception);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes distance from lowest to highest element offset in given tensor.
|
||||
*/
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user