Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-18 17:45:09 +08:00)

Compare commits: 3 commits, ciflow/tru...csl/lintru

| Author | SHA1 | Date |
|---|---|---|
| | bdc1181606 | |
| | f574505205 | |
| | 4dd22ac43c | |

@@ -13,4 +13,3 @@ exclude:
   - "**/benchmarks/**"
   - "**/test_*.py"
   - "**/*_test.py"
-  - "tools/**"
@@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
 ENV LANG en_US.UTF-8
 ENV LANGUAGE en_US.UTF-8
 
-ARG DEVTOOLSET_VERSION=13
+ARG DEVTOOLSET_VERSION=11
 
 RUN yum -y update
 RUN yum -y install epel-release
 # install glibc-langpack-en make sure en_US.UTF-8 locale is available
 RUN yum -y install glibc-langpack-en
-RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
+RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
 # Just add everything as a safe.directory for git since these will be used in multiple places with git
 RUN git config --global --add safe.directory '*'
 ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
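
With the PATH line above in place, the toolset compiler shadows the system one. A minimal sketch of that mechanism (the version number is assumed, not taken from the image):

```bash
# Sketch: prepending the gcc-toolset bin directory to PATH, as the hunk does,
# makes its gcc resolve first; no `scl enable` shell is needed.
export DEVTOOLSET_VERSION=11
export PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
command -v gcc    # expected: /opt/rh/gcc-toolset-11/root/usr/bin/gcc
gcc --version | head -n 1
```
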
@@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
 # Install CUDA
 FROM base as cuda
 ARG CUDA_VERSION=12.6
-ARG DEVTOOLSET_VERSION=13
 RUN rm -rf /usr/local/cuda-*
 ADD ./common/install_cuda.sh install_cuda.sh
 COPY ./common/install_nccl.sh install_nccl.sh
@@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
 # Preserve CUDA_VERSION for the builds
 ENV CUDA_VERSION=${CUDA_VERSION}
 # Make things in our path by default
-ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
-
+ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
 
 FROM cuda as cuda12.6
 RUN bash ./install_cuda.sh 12.6
@@ -70,22 +68,8 @@ FROM cuda as cuda13.0
 RUN bash ./install_cuda.sh 13.0
 ENV DESIRED_CUDA=13.0
 
-FROM ${ROCM_IMAGE} as rocm_base
-ARG DEVTOOLSET_VERSION=13
-ENV LC_ALL en_US.UTF-8
-ENV LANG en_US.UTF-8
-ENV LANGUAGE en_US.UTF-8
-# Install devtoolset on ROCm base image
-RUN yum -y update && \
-    yum -y install epel-release && \
-    yum -y install glibc-langpack-en && \
-    yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
-RUN git config --global --add safe.directory '*'
-ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
-
-FROM rocm_base as rocm
+FROM ${ROCM_IMAGE} as rocm
 ARG PYTORCH_ROCM_ARCH
-ARG DEVTOOLSET_VERSION=13
 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
 ADD ./common/install_mkl.sh install_mkl.sh
 RUN bash ./install_mkl.sh && rm install_mkl.sh
@@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
 
 # Final step
 FROM ${BASE_TARGET} as final
-ARG DEVTOOLSET_VERSION=13
 COPY --from=openssl /opt/openssl /opt/openssl
 COPY --from=patchelf /patchelf /usr/local/bin/patchelf
 COPY --from=conda /opt/conda /opt/conda
@@ -63,7 +63,7 @@ docker build \
   --target final \
   --progress plain \
   --build-arg "BASE_TARGET=${BASE_TARGET}" \
-  --build-arg "DEVTOOLSET_VERSION=13" \
+  --build-arg "DEVTOOLSET_VERSION=11" \
   ${EXTRA_BUILD_ARGS} \
   -t ${tmp_tag} \
   $@ \
@@ -195,16 +195,13 @@ case "$tag" in
     NINJA_VERSION=1.9.0
     TRITON=yes
     ;;
-  pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
+  pytorch-linux-jammy-xpu-n-py3)
     ANACONDA_PYTHON_VERSION=3.10
     GCC_VERSION=11
     VISION=yes
     XPU_VERSION=2025.2
     NINJA_VERSION=1.9.0
     TRITON=yes
-    if [[ $tag =~ "benchmarks" ]]; then
-      INDUCTOR_BENCHMARKS=yes
-    fi
     ;;
   pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
     ANACONDA_PYTHON_VERSION=3.10
@@ -261,9 +258,9 @@ case "$tag" in
     PYTHON_VERSION=3.10
     CUDA_VERSION=12.8.1
     ;;
-  pytorch-linux-jammy-aarch64-py3.10-gcc13)
+  pytorch-linux-jammy-aarch64-py3.10-gcc11)
     ANACONDA_PYTHON_VERSION=3.10
-    GCC_VERSION=13
+    GCC_VERSION=11
     ACL=yes
     VISION=yes
     OPENBLAS=yes
@@ -271,19 +268,9 @@ case "$tag" in
     # from pytorch/llvm:9.0.1 is x86 specific
     SKIP_LLVM_SRC_BUILD_INSTALL=yes
     ;;
-  pytorch-linux-jammy-aarch64-py3.10-clang21)
+  pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
     ANACONDA_PYTHON_VERSION=3.10
-    CLANG_VERSION=21
-    ACL=yes
-    VISION=yes
-    OPENBLAS=yes
-    # snadampal: skipping llvm src build install because the current version
-    # from pytorch/llvm:9.0.1 is x86 specific
-    SKIP_LLVM_SRC_BUILD_INSTALL=yes
-    ;;
-  pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
-    ANACONDA_PYTHON_VERSION=3.10
-    GCC_VERSION=13
+    GCC_VERSION=11
     ACL=yes
     VISION=yes
     OPENBLAS=yes
@@ -1 +1 @@
-bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
+7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
@@ -3,7 +3,7 @@
 
 set -eux
 
-ACL_VERSION=${ACL_VERSION:-"v52.6.0"}
+ACL_VERSION=${ACL_VERSION:-"v25.02"}
 ACL_INSTALL_DIR="/acl"
 
 # Clone ACL
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
     # work around ubuntu apt-get conflicts
     sudo apt-get -y -f install
     wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
-    if [[ $CLANG_VERSION -ge 18 ]]; then
-      apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
+    if [[ $CLANG_VERSION == 18 ]]; then
+      apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
     fi
 fi
 
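
The removed `-ge 18` branch keys the LLVM apt repository on the requested major version, so any newer clang picks up a matching repo, while the replacement only handles clang 18. A small sketch of the removed behavior (the version value here is hypothetical):

```bash
# Sketch: with the removed comparison, CLANG_VERSION=21 would select the
# llvm-toolchain-jammy-21 repo; the `== 18` form matches only clang 18.
CLANG_VERSION=21
if [[ $CLANG_VERSION -ge 18 ]]; then
  echo "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
fi
```
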
@@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
     # Need the official toolchain repo to get alternate packages
     add-apt-repository ppa:ubuntu-toolchain-r/test
     apt-get update
-    apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
+    apt-get install -y g++-$GCC_VERSION
     update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
     update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
     update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
-    update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
 
     # Cleanup package manager
     apt-get autoclean && apt-get clean
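
The `update-alternatives` calls register each tool at priority 50 so that `/usr/bin/gcc` and friends resolve to the versioned binaries. A quick way to inspect the result, sketched under the same assumptions as the hunk:

```bash
# Sketch: check where the alternatives system routes gcc after installation.
update-alternatives --display gcc | head -n 3
readlink -f /usr/bin/gcc    # e.g. /usr/bin/gcc-11 once the alternative is set
```
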
@@ -1,56 +0,0 @@
-#!/bin/bash
-# Script used only in CD pipeline
-
-set -ex
-
-# install dependencies
-dnf -y install gmp-devel libmpc-devel texinfo flex bison
-
-cd /usr/local/src
-# fetch source for gcc 13
-git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
-
-mkdir -p gcc-13.3.0/build-gomp
-cd gcc-13.3.0/build-gomp
-
-# configure gcc build
-# I got these flags by:
-# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
-#    dnf install -y dnf-plugins-core rpmdevtools
-#    dnf download --source libgomp
-# 2. extracting the gcc.spec from the source.
-#    rpmdev-extract gcc-xx.src.rpm
-# 3. extracting optflags and ld_flags from gcc.spec:
-#    rpm --eval '%{optflags}'
-#    rpm --eval '%{build_ldflags}'
-#
-# I had to remove the following flags because they didn't compile for this version of libgomp:
-#    -Werror=format-security
-#    -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
-#    -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
-#
-# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
-
-OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
-' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
-' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
-' -fstack-protector-strong -fasynchronous-unwind-tables'\
-' -fstack-clash-protection'
-
-LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
-
-CFLAGS="$OPT_FLAGS" \
-CXXFLAGS="$OPT_FLAGS" \
-LDFLAGS="$LDFLAGS" \
-../configure \
-  --prefix=/usr \
-  --libdir=/usr/lib64 \
-  --enable-languages=c,c++ \
-  --disable-multilib \
-  --disable-bootstrap \
-  --enable-libgomp
-
-# only build libgomp
-make -j$(nproc) all-target-libgomp
-
-make install-target-libgomp
@@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
 
 OPENBLAS_CHECKOUT_DIR="OpenBLAS"
 OPENBLAS_BUILD_FLAGS="
-CC=gcc
 NUM_THREADS=128
 USE_OPENMP=1
 NO_SHARED=0
@@ -12,8 +12,8 @@ function do_install() {
 
     rocm_version_nodot=${rocm_version//./}
 
-    # post merge of https://github.com/icl-utk-edu/magma/pull/65
-    MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
+    # https://github.com/icl-utk-edu/magma/pull/65
+    MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
     magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
 
     rocm_dir="/opt/rocm"
@@ -149,7 +149,7 @@ FROM cpu_final as rocm_final
 ARG ROCM_VERSION=6.0
 ARG PYTORCH_ROCM_ARCH
 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
-ARG DEVTOOLSET_VERSION=13
+ARG DEVTOOLSET_VERSION=11
 ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib"
 # Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path,
 # below workaround helps avoid error
@@ -50,10 +50,6 @@ RUN rm install_ninja.sh
 ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
 ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
 
-# Build a newer version of libgomp than that supported in in Almalinux 8.
-COPY ./common/install_libgomp.sh install_libgomp.sh
-RUN bash ./install_libgomp.sh && rm install_libgomp.sh
-
 # git236+ would refuse to run git commands in repos owned by other users
 # Which causes version check to fail, as pytorch repo is bind-mounted into the image
 # Override this behaviour by treating every folder as safe
@@ -97,7 +97,7 @@ case ${image} in
   manylinux2_28-builder:xpu)
     TARGET=xpu_final
     GPU_IMAGE=amd64/almalinux:8
-    DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13"
+    DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=11"
     MANY_LINUX_VERSION="2_28"
     ;;
   *)
@@ -1,11 +1,15 @@
-sphinx==7.2.6
+sphinx==5.3.0
 #Description: This is used to generate PyTorch docs
-#Pinned versions: 7.2.6
+#Pinned versions: 5.3.0
 
-pytorch_sphinx_theme2==0.2.0
-#Description: This is needed to generate PyTorch docs
-#Pinned versions: 0.2.0
+standard-imghdr==3.13.0; python_version >= "3.13"
+#Description: This is needed by Sphinx, so it needs to be added here.
+# The reasons are as follows:
+# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
+# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
+# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
 
+-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
 # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
 # but it doesn't seem to work and hangs around idly. The initial thought that it is probably
 # something related to Docker setup. We can investigate this later.
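
The `python_version >= "3.13"` suffix on the added requirement is a standard environment marker, so the `imghdr` backport is only installed where the stdlib module is gone. A minimal sketch of how the marker behaves against whatever interpreter is active:

```bash
# Sketch: the marker makes the install a no-op on interpreters older than 3.13,
# where the stdlib imghdr module still exists.
python3 -c 'import sys; print(sys.version_info >= (3, 13))'
pip install 'standard-imghdr==3.13.0; python_version >= "3.13"'
```
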
@@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
 #Description: This is used to generate PyTorch docs
 #Pinned versions: 2.13.0
 
-breathe==4.36.0
+breathe==4.34.0
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 4.36.0
+#Pinned versions: 4.34.0
 
-exhale==0.3.7
+exhale==0.2.3
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.3.7
+#Pinned versions: 0.2.3
 
-docutils==0.20
+docutils==0.16
 #Description: This is used to generate PyTorch C++ docs
-#Pinned versions: 0.20
+#Pinned versions: 0.16
 
 bs4==0.0.1
 #Description: This is used to generate PyTorch C++ docs
@@ -52,13 +56,13 @@ IPython==8.12.0
 #Description: This is used to generate PyTorch functorch docs
 #Pinned versions: 8.12.0
 
-myst-nb==1.3.0
+myst-nb==0.17.2
 #Description: This is used to generate PyTorch functorch and torch.compile docs.
-#Pinned versions: 1.3.0
+#Pinned versions: 0.17.2
 
 # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
 python-etcd==0.4.5
 sphinx-copybutton==0.5.0
-sphinx-design==0.6.1
+sphinx-design==0.4.0
 sphinxcontrib-mermaid==1.0.0
-myst-parser==4.0.1
+myst-parser==0.18.1
@@ -1 +1 @@
-3.5.1
+3.5.0
@@ -54,15 +54,12 @@ ENV OPENSSL_DIR /opt/openssl
 RUN rm install_openssl.sh
 
 ARG INDUCTOR_BENCHMARKS
-ARG ANACONDA_PYTHON_VERSION
-ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
 COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
 COPY ./common/common_utils.sh common_utils.sh
 COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
 COPY ci_commit_pins/timm.txt timm.txt
-COPY ci_commit_pins/torchbench.txt torchbench.txt
 RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
-RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt
+RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
 
 # Install XPU Dependencies
 ARG XPU_VERSION
@@ -1,7 +1,7 @@
 SHELL=/usr/bin/env bash
 
 DOCKER_CMD ?= docker
-DESIRED_ROCM ?= 7.1
+DESIRED_ROCM ?= 7.0
 DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
 PACKAGE_NAME = magma-rocm
 # inherit this from underlying docker image, do not pass this env var to docker
@@ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
 	magma-rocm/build_magma.sh
 
 .PHONY: all
-all: magma-rocm71
 all: magma-rocm70
 all: magma-rocm64
 
@@ -25,11 +24,6 @@ clean:
 	$(RM) -r magma-*
 	$(RM) -r output
 
-.PHONY: magma-rocm71
-magma-rocm71: DESIRED_ROCM := 7.1
-magma-rocm71:
-	$(DOCKER_RUN)
-
 .PHONY: magma-rocm70
 magma-rocm70: DESIRED_ROCM := 7.0
 magma-rocm70:
@@ -426,7 +426,7 @@ fi
 if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
   # export test times so that potential sharded tests that'll branch off this build will use consistent data
   # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
-  PYTHONPATH=. python tools/stats/export_test_times.py
+  python tools/stats/export_test_times.py
 fi
 # don't do this for bazel or s390x or riscv64 as they don't use sccache
 if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then
@@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
 
   make coverage
   # Now we have the coverage report, we need to make sure it is empty.
-  # Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
-  # showing the undocumented count in the third column.
-  # Example: | TOTAL | 99.83% | 2 |
+  # Count the number of lines in the file and turn that number into a variable
+  # $lines. The `cut -f1 ...` is to only parse the number, not the filename
+  # Skip the report header by subtracting 2: the header will be output even if
+  # there are no undocumented items.
   #
   # Also: see docs/source/conf.py for "coverage_ignore*" items, which should
   # be documented then removed from there.
-  # Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
-  # The table format is: | Module | Coverage | Undocumented |
-  # Extract the third column (undocumented count) from the TOTAL row
-  undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
-
-  if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
+  lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
+  undocumented=$((lines - 2))
+  if [ $undocumented -lt 0 ]; then
     echo coverage output not found
     exit 1
-  elif [ "$undocumented" -gt 0 ]; then
-    set +x # Disable command echoing for cleaner output
-    echo ""
-    echo "====================="
-    echo "UNDOCUMENTED OBJECTS:"
-    echo "====================="
-    echo ""
-    # Find the line number of the TOTAL row and print only what comes after it
-    total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
-    if [ -n "$total_line" ]; then
-      # Print only the detailed list (skip the statistics table)
-      tail -n +$((total_line + 2)) build/coverage/python.txt
-    else
-      # Fallback to showing entire file if TOTAL line not found
-      cat build/coverage/python.txt
-    fi
-    echo ""
+  elif [ $undocumented -gt 0 ]; then
+    echo undocumented objects found:
+    cat build/coverage/python.txt
     echo "Make sure you've updated relevant .rsts in docs/source!"
-    echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
-    set -x # Re-enable command echoing
+    echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
     exit 1
   fi
 else
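
For context, the removed TOTAL-row parsing pulls the third data column out of the Sphinx statistics table. A minimal sketch against hypothetical report contents:

```bash
# Sketch (sample data is made up): extract the undocumented count the way the
# removed code does, splitting the TOTAL row on '|' and taking the 4th field.
report='| Module | Coverage | Undocumented |
| torch  | 99.83%   | 2            |
| TOTAL  | 99.83%   | 2            |'
undocumented=$(printf '%s\n' "$report" | grep "| TOTAL" | awk -F'|' '{print $4}' | tr -d ' ')
echo "$undocumented"    # prints 2; empty or non-numeric output trips the guard above
```
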
@@ -337,7 +337,7 @@ test_python() {
 
 test_python_smoke() {
   # Smoke tests for H100/B200
-  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
+  time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
   assert_git_not_dirty
 }
 
@@ -572,8 +572,6 @@ fi
 
 if [[ "${TEST_CONFIG}" == *cpu* ]]; then
   DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
-elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
-  DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
 else
   DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
 fi
@@ -667,8 +665,6 @@ test_perf_for_dashboard() {
     device=cuda_b200
   elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
     device=rocm
-  elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
-    device=xpu
   fi
 
   for mode in "${modes[@]}"; do
@@ -1653,7 +1649,7 @@ test_operator_microbenchmark() {
 
   cd "${TEST_DIR}"/benchmarks/operator_benchmark
 
-  for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do
+  for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
     $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
       --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
       --benchmark-name "PyTorch operator microbenchmark" --use-compile
@@ -1761,7 +1757,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
   else
     # Do this after checkout_install_torchbench to ensure we clobber any
     # nightlies that torchbench may pull in
-    if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then
+    if [[ "${TEST_CONFIG}" != *cpu* ]]; then
       install_torchrec_and_fbgemm
     fi
     PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
@@ -70,7 +70,7 @@ sccache --zero-stats
 sccache --show-stats
 
 # Build the wheel
-python -m build --wheel --no-isolation
+python -m build --wheel --no-build-isolation
 if ($LASTEXITCODE -ne 0) { exit 1 }
 
 # Install the wheel locally
@@ -60,11 +60,9 @@ performance-*,
 readability-container-size-empty,
 readability-delete-null-pointer,
 readability-duplicate-include,
-readability-named-parameter,
 readability-misplaced-array-index,
 readability-redundant*,
 readability-simplify-subscript-expr,
-readability-static-definition-in-anonymous-namespace
 readability-string-compare,
 -readability-redundant-access-specifiers,
 -readability-redundant-control-flow,
@@ -1,319 +0,0 @@
----
-name: add-uint-support
-description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
----
-
-# Add Unsigned Integer (uint) Support to Operators
-
-This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
-
-## When to use this skill
-
-Use this skill when:
-- Adding uint16, uint32, or uint64 support to an operator
-- User mentions "unsigned types", "uint support", "barebones unsigned types"
-- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
-- Working with operator implementations that need expanded type coverage
-
-## Quick reference
-
-**Add unsigned types to existing dispatch:**
-```cpp
-// Before
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES));
-
-// After (method 1: add unsigned types explicitly)
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
-
-// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
-```
-
-## Type group reference
-
-**Unsigned type groups:**
-- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
-- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
-
-**Relationship:**
-```cpp
-AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
-AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
-AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
-```
-
-## Instructions
-
-### Step 1: Determine if conversion to V2 is needed
-
-Check if the file uses AT_DISPATCH_V2:
-
-**If using old AT_DISPATCH:**
-- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
-- Then proceed with adding uint support
-
-**If already using AT_DISPATCH_V2:**
-- Proceed directly to Step 2
-
-### Step 2: Analyze the current dispatch macro
-
-Identify what type groups are currently in use:
-
-```cpp
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  // body
-}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
-    ^^^^^^^^^^^^^^^^^^^^^^^^^
-    Current type coverage
-```
-
-Common patterns:
-- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
-- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only
-- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types
-
-### Step 3: Choose the uint addition method
-
-Two approaches:
-
-**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
-- Use when: You want to be explicit about adding uint support
-- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list
-
-**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
-- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
-- More concise: replaces one type group with its superset
-- Only applicable if AT_INTEGRAL_TYPES is present
-
-### Step 4: Apply the transformation
-
-**Method 1 example:**
-```cpp
-// Before
-AT_DISPATCH_V2(
-  dtype,
-  "min_values_cuda",
-  AT_WRAP([&]() {
-    kernel_impl<scalar_t>(iter);
-  }),
-  AT_EXPAND(AT_ALL_TYPES),
-  kBFloat16, kHalf, kBool
-);
-
-// After (add unsigned types)
-AT_DISPATCH_V2(
-  dtype,
-  "min_values_cuda",
-  AT_WRAP([&]() {
-    kernel_impl<scalar_t>(iter);
-  }),
-  AT_EXPAND(AT_ALL_TYPES),
-  AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
-  kBFloat16, kHalf, kBool
-);
-```
-
-**Method 2 example:**
-```cpp
-// Before
-AT_DISPATCH_V2(
-  dtype,
-  "integral_op",
-  AT_WRAP([&]() {
-    kernel<scalar_t>();
-  }),
-  AT_EXPAND(AT_INTEGRAL_TYPES)
-);
-
-// After (substitute with V2)
-AT_DISPATCH_V2(
-  dtype,
-  "integral_op",
-  AT_WRAP([&]() {
-    kernel<scalar_t>();
-  }),
-  AT_EXPAND(AT_INTEGRAL_TYPES_V2)
-);
-```
-
-### Step 5: Handle AT_ALL_TYPES vs individual type groups
-
-If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
-- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
-- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list
-
-If the dispatch separately lists INTEGRAL and FLOATING:
-```cpp
-// Before
-AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
-
-// After (Method 2 preferred)
-AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
-```
-
-### Step 6: Verify all dispatch sites
-
-Check the file for ALL dispatch macros that need uint support:
-- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
-- Apply the transformation consistently across all sites
-- Ensure each gets the same type coverage updates
-
-### Step 7: Validate the changes
-
-Check that:
-- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
-- [ ] Unsigned types are added via one of the two methods
-- [ ] All relevant dispatch sites in the file are updated
-- [ ] Type groups use `AT_EXPAND()`
-- [ ] Arguments are properly formatted and comma-separated
-
-## Common patterns
-
-### Pattern 1: AT_ALL_TYPES + extras
-
-```cpp
-// Before
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
-
-// After
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
-```
-
-### Pattern 2: Separate INTEGRAL + FLOATING
-
-```cpp
-// Before
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
-
-// After
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
-```
-
-### Pattern 3: Old dispatch needs conversion first
-
-```cpp
-// Before (needs v2 conversion first)
-AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
-  kernel<scalar_t>();
-});
-
-// After v2 conversion
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
-
-// After adding uint support
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
-```
-
-## Multiple dispatch sites example
-
-For a file with multiple functions:
-
-```cpp
-void min_values_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
-    impl<scalar_t>(iter);
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
-  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-  //                           Added uint support
-}
-
-void min_launch_kernel(TensorIterator &iter) {
-  AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
-    gpu_reduce_kernel<scalar_t>(iter);
-  }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
-  //                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-  //                           Added uint support here too
-}
-```
-
-## Decision tree
-
-Use this decision tree to determine the approach:
-
-```
-Is the file using AT_DISPATCH_V2?
-├─ No → Use at-dispatch-v2 skill first, then continue
-└─ Yes
-   └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
-      ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
-      └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
-```
-
-## Edge cases
-
-### Case 1: Dispatch with only floating types
-
-If the operator only supports floating point types, don't add uint support:
-
-```cpp
-// Leave as-is - floating point only operator
-AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
-```
-
-### Case 2: Complex types present
-
-Unsigned types work alongside complex types:
-
-```cpp
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES),
-    AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
-    AT_EXPAND(AT_COMPLEX_TYPES),
-    kHalf, kBFloat16);
-```
-
-### Case 3: Already has uint support
-
-Check if uint types are already present:
-- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
-- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
-- Skip the file if uint support is already present
-
-## Workflow
-
-When asked to add uint support:
-
-1. Read the target file
-2. Check if using AT_DISPATCH_V2:
-   - If not → use at-dispatch-v2 skill first
-3. Identify all dispatch macro sites
-4. For each dispatch:
-   - Analyze current type groups
-   - Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
-   - Apply transformation with Edit tool
-5. Show the user the changes
-6. Explain what was modified
-
-## Important notes
-
-- Always check if v2 conversion is needed first
-- Apply changes consistently across all dispatch sites in the file
-- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
-- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
-- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
-- Some operators may not semantically support unsigned types - use judgment
-
-## Testing
-
-After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.
@@ -1,305 +0,0 @@
----
-name: at-dispatch-v2
-description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
----
-
-# AT_DISPATCH to AT_DISPATCH_V2 Converter
-
-This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.
-
-## When to use this skill
-
-Use this skill when:
-- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
-- Porting ATen kernels to use the new dispatch API
-- Working with files in `aten/src/ATen/native/` that use dispatch macros
-- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
-
-## Quick reference
-
-**Old format:**
-```cpp
-AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
-  // lambda body
-});
-```
-
-**New format:**
-```cpp
-AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
-  // lambda body
-}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
-```
-
-## Key transformations
-
-1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
-2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
-3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
-4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
-5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
-
-## Instructions
-
-### Step 1: Add the Dispatch_v2.h include
-
-Add the v2 header near the existing `#include <ATen/Dispatch.h>`:
-
-```cpp
-#include <ATen/Dispatch.h>
-#include <ATen/Dispatch_v2.h>
-```
-
-Keep the old Dispatch.h include for now (other code may still need it).
-
-### Step 2: Identify the old dispatch pattern
-
-Common patterns to convert:
-
-- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
-- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
-- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
-- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
-
-### Step 3: Map the old macro to type groups
-
-Identify which type group macro corresponds to the base types:
-
-| Old macro base | AT_DISPATCH_V2 type group |
-|----------------|---------------------------|
-| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
-| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
-| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
-| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
-| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
-
-For combined patterns, use multiple `AT_EXPAND()` entries:
-```cpp
-// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
-// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
-```
-
-### Step 4: Extract the individual types
-
-From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).
-
-These become the trailing arguments after the type group:
-```cpp
-AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
-                                             ^^^^^^^^^^^^^^^^^^^^^^^
-                                             Individual types from AND3
-```
-
-### Step 5: Transform to AT_DISPATCH_V2
-
-Apply the transformation:
-
-**Pattern:**
-```cpp
-AT_DISPATCH_V2(
-  scalar_type,      // 1st: The dtype expression
-  "name",           // 2nd: The debug string
-  AT_WRAP(lambda),  // 3rd: The lambda wrapped in AT_WRAP
-  type_groups,      // 4th+: Type groups with AT_EXPAND()
-  individual_types  // Last: Individual types
-)
-```
-
-**Example transformation:**
-```cpp
-// BEFORE
-AT_DISPATCH_ALL_TYPES_AND3(
-  kBFloat16, kHalf, kBool,
-  iter.dtype(),
-  "min_values_cuda",
-  [&]() {
-    min_values_kernel_cuda_impl<scalar_t>(iter);
-  }
-);
-
-// AFTER
-AT_DISPATCH_V2(
-  iter.dtype(),
-  "min_values_cuda",
-  AT_WRAP([&]() {
-    min_values_kernel_cuda_impl<scalar_t>(iter);
-  }),
-  AT_EXPAND(AT_ALL_TYPES),
-  kBFloat16, kHalf, kBool
-);
-```
-
-### Step 6: Handle multi-line lambdas
-
-For lambdas with internal commas or complex expressions, AT_WRAP is essential:
-
-```cpp
-AT_DISPATCH_V2(
-  dtype,
-  "complex_kernel",
-  AT_WRAP([&]() {
-    gpu_reduce_kernel<scalar_t, scalar_t>(
-      iter,
-      MinOps<scalar_t>{},
-      thrust::pair<scalar_t, int64_t>(upper_bound(), 0)  // Commas inside!
-    );
-  }),
-  AT_EXPAND(AT_ALL_TYPES)
-);
-```
-
-### Step 7: Verify the conversion
-
-Check that:
-- [ ] `AT_WRAP()` wraps the entire lambda
-- [ ] Type groups use `AT_EXPAND()`
-- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
-- [ ] Argument order is: scalar_type, name, lambda, types
-- [ ] Include added: `#include <ATen/Dispatch_v2.h>`
-
-## Type group reference
-
-Available type group macros (use with `AT_EXPAND()`):

-```cpp
-AT_INTEGRAL_TYPES            // kByte, kChar, kInt, kLong, kShort
-AT_FLOATING_TYPES            // kDouble, kFloat
-AT_COMPLEX_TYPES             // kComplexDouble, kComplexFloat
-AT_QINT_TYPES                // kQInt8, kQUInt8, kQInt32
-AT_ALL_TYPES                 // INTEGRAL_TYPES + FLOATING_TYPES
-AT_ALL_TYPES_AND_COMPLEX     // ALL_TYPES + COMPLEX_TYPES
-AT_INTEGRAL_TYPES_V2         // INTEGRAL_TYPES + unsigned types
-AT_BAREBONES_UNSIGNED_TYPES  // kUInt16, kUInt32, kUInt64
-AT_FLOAT8_TYPES              // Float8 variants
-```
-
-## Common patterns
-
-### Pattern: AT_DISPATCH_ALL_TYPES_AND2
-
-```cpp
-// Before
-AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
-  kernel<scalar_t>(data);
-});
-
-// After
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>(data);
-}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
-```
-
-### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
-
-```cpp
-// Before
-AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
-  tensor.scalar_type(), "float_op", [&] {
-    process<scalar_t>(tensor);
-  });
-
-// After
-AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
-  process<scalar_t>(tensor);
-}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
-```
-
-### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
-
-```cpp
-// Before
-AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-  kComplexHalf, kHalf,
-  self.scalar_type(),
-  "complex_op",
-  [&] {
-    result = compute<scalar_t>(self);
-  }
-);
-
-// After
-AT_DISPATCH_V2(
-  self.scalar_type(),
-  "complex_op",
-  AT_WRAP([&] {
-    result = compute<scalar_t>(self);
-  }),
-  AT_EXPAND(AT_ALL_TYPES),
-  AT_EXPAND(AT_COMPLEX_TYPES),
-  kComplexHalf,
-  kHalf
-);
-```
-
-## Edge cases
-
-### Case 1: No extra types (rare)
-
-```cpp
-// Before
-AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
-
-// After
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES));
-```
-
-### Case 2: Many individual types (AND4, AND5, etc.)
-
-```cpp
-// Before
-AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
-  dtype, "float8_op", [&]() { kernel<scalar_t>(); });
-
-// After
-AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
-  kernel<scalar_t>();
-}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
-```
-
-### Case 3: Lambda with no captures
-
-```cpp
-// Before
-AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
-  static_kernel<scalar_t>();
-});
-
-// After
-AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
-  static_kernel<scalar_t>();
-}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
-```
-
-## Benefits of AT_DISPATCH_V2
-
-1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
-2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
-3. **Extensible**: Easy to add more types without hitting macro limits
-4. **Clearer**: Type groups are explicit, not implicit in macro name
-
-## Important notes
-
-- Keep `#include <ATen/Dispatch.h>` - other code may need it
-- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda
-- Type groups need `AT_EXPAND()`, individual types don't
-- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
-- See the header file for the Python script to regenerate the macro implementation
-
-## Workflow
-
-When asked to convert AT_DISPATCH macros:
-
-1. Read the file to identify all AT_DISPATCH uses
-2. Add `#include <ATen/Dispatch_v2.h>` if not present
-3. For each dispatch macro:
-   - Identify the pattern and extract components
-   - Map the base type group
-   - Extract individual types
-   - Construct the AT_DISPATCH_V2 call
-   - Apply with Edit tool
-4. Show the user the complete converted file
-5. Explain what was changed
-
-Do NOT compile or test the code - focus on accurate conversion only.
@@ -1,11 +1,11 @@
-name: 🚀 New Feature for Release
+name: 🚀 Release highlight for proposed Feature
 description: Submit a Release highlight for proposed Feature
 labels: ["release-feature-request"]
 
 body:
   - type: textarea
     attributes:
-      label: New Feature for Release
+      label: Release highlight for proposed Feature
       description: >
        Example: “A torch.special module, analogous to SciPy's special module.”
   - type: input
4
.github/actions/diskspace-cleanup/action.yml
vendored
4
.github/actions/diskspace-cleanup/action.yml
vendored
@ -27,9 +27,7 @@ runs:
|
|||||||
docker system prune -af
|
docker system prune -af
|
||||||
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
|
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
|
||||||
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
|
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
|
||||||
diskspace_cutoff_int=$((diskspace_cutoff + 0))
|
echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
|
||||||
difference=$((100 - diskspace_cutoff_int))
|
|
||||||
echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
|
|
||||||
echo "$msg"
|
echo "$msg"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
|
.github/actions/pytest-cache-download/action.yml (12 changes)

@@ -38,9 +38,9 @@ runs:
       run: |
         python3 .github/scripts/pytest_cache.py \
           --download \
-          --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
-          --pr_identifier "$GITHUB_REF" \
-          --job_identifier "$JOB_IDENTIFIER" \
-          --temp_dir "$RUNNER_TEMP" \
-          --repo "$REPO" \
-          --bucket "$BUCKET" \
+          --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
+          --pr_identifier $GITHUB_REF \
+          --job_identifier $JOB_IDENTIFIER \
+          --temp_dir $RUNNER_TEMP \
+          --repo $REPO \
+          --bucket $BUCKET \
.github/actions/pytest-cache-upload/action.yml (16 changes)

@@ -47,11 +47,11 @@ runs:
       run: |
         python3 .github/scripts/pytest_cache.py \
           --upload \
-          --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
-          --pr_identifier "$GITHUB_REF" \
-          --job_identifier "$JOB_IDENTIFIER" \
-          --sha "$SHA" \
-          --test_config "$TEST_CONFIG" \
-          --shard "$SHARD" \
-          --repo "$REPO" \
-          --temp_dir "$RUNNER_TEMP" \
+          --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
+          --pr_identifier $GITHUB_REF \
+          --job_identifier $JOB_IDENTIFIER \
+          --sha $SHA \
+          --test_config $TEST_CONFIG \
+          --shard $SHARD \
+          --repo $REPO \
+          --temp_dir $RUNNER_TEMP \
.github/ci_commit_pins/audio.txt (2 changes)

@@ -1 +1 @@
-ad5816f0eee1c873df1b7d371c69f1f811a89387
+3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2
.github/ci_commit_pins/vision.txt (2 changes)

@@ -1 +1 @@
-cfbc5c2f1c798991715a6b06bb3ce46478c4487c
+218d2ab791d437309f91e0486eb9fa7f00badc17
.github/ci_commit_pins/xla.txt (2 changes)

@@ -1 +1 @@
-c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
+df6798dfb931ce7c7fe5bed2447cd1092a5981af
.github/copilot-instructions.md (125 changes)

@@ -1,125 +0,0 @@
-# PyTorch Copilot Instructions
-
-This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
-
-## Architecture Overview
-
-### Core Components
-
-- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
-- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
-  - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
-  - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
-- **torch/** - Python bindings and public API
-  - `torch/csrc/` - C++ Python bindings (hand-written and generated)
-  - `torch/csrc/autograd/` - Reverse-mode automatic differentiation
-  - `torch/csrc/jit/` - TorchScript JIT compiler
-- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
-- **tools/** - Build scripts, autograd derivatives, code generation
-
-### The Code Generation Workflow
-
-**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
-1. Declares operator signatures, variants (function/method), and dispatch behavior
-2. Gets processed by `torchgen/` to generate C++/Python bindings
-3. Produces headers in `build/aten/src/ATen/` during compilation
-
-Example entry structure:
-```yaml
-- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
-  variants: function, method
-  dispatch:
-    CPU: my_op_cpu
-    CUDA: my_op_cuda
-```
-
-After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
-
-## Development Workflows
-
-### Building from Source
-
-**Never run `setup.py` directly** - use pip with editable install:
-```bash
-python -m pip install --no-build-isolation -v -e .
-```
-
-Speed up builds:
-- `DEBUG=1` - Debug symbols with `-g -O0`
-- `USE_CUDA=0` - Skip CUDA compilation
-- `BUILD_TEST=0` - Skip C++ test binaries
-- Install `ninja` (`pip install ninja`) for faster builds
-- Use `ccache` for incremental compilation caching
-
-Rebuild specific targets: `(cd build && ninja <target>)`
-
-### Testing
-
-**Critical**: DO NOT run entire test suites. Run specific tests only:
-```bash
-python test/test_torch.py TestTorch.test_specific_case
-```
-
-**Test structure**: All tests use `torch.testing._internal.common_utils`:
-```python
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-class TestFeature(TestCase):
-    def test_something(self):
-        # Use self.assertEqual for tensor comparisons
-        pass
-
-if __name__ == "__main__":
-    run_tests()
-```
-
-**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
-
-### Linting
-
-Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
-
-## Project-Specific Conventions
-
-### Memory and Storage
-- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
-- CUDA device info lives in storage objects
-
-### Python-C++ Integration (`torch/csrc/`)
-- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
-- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
-- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
-
-### Dispatch System
-- PyTorch uses operator dispatch to route calls to backend-specific kernels
-- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
-- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
-
-## Git Workflow (AI Agent Specific)
-
-When preparing PRs from this environment:
-```bash
-git stash -u
-git reset --hard $(cat /tmp/orig_work.txt)  # Reset to LOCAL branch
-git stash pop
-# Resolve conflicts if necessary
-```
-
-## Common Gotchas
-
-1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
-2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
-3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
-4. **No internet access** - DO NOT attempt to install dependencies during development
-
-## Key Files Reference
-
-- `AGENTS.md` - Instructions specific to AI coding agents
-- `CONTRIBUTING.md` - Comprehensive human contributor guide
-- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
-- `aten/src/ATen/native/README.md` - Operator implementation guide
-- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
-
-## Performance Debugging
-
-Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.
.github/pytorch-probot.yml (1 change)

@@ -19,7 +19,6 @@ ciflow_push_tags:
 - ciflow/inductor-perf-test-nightly-rocm-mi300
 - ciflow/inductor-perf-test-nightly-rocm-mi355
 - ciflow/inductor-perf-test-nightly-x86-zen
-- ciflow/inductor-perf-test-nightly-xpu
 - ciflow/inductor-periodic
 - ciflow/inductor-rocm
 - ciflow/linux-aarch64
.github/scripts/generate_binary_build_matrix.py (91 changes)

@@ -11,24 +11,18 @@ architectures:
 * Latest XPU
 """

-import json
 import os
-import re
-from pathlib import Path
 from typing import Optional


-SCRIPT_DIR = Path(__file__).absolute().parent
-REPO_ROOT = SCRIPT_DIR.parent.parent
-
-
+# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
 CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
 CUDA_STABLE = "12.8"
 CUDA_ARCHES_FULL_VERSION = {
     "12.6": "12.6.3",
     "12.8": "12.8.1",
     "12.9": "12.9.1",
-    "13.0": "13.0.0",
+    "13.0": "13.0.2",
 }
 CUDA_ARCHES_CUDNN_VERSION = {
     "12.6": "9",

@@ -37,7 +31,8 @@ CUDA_ARCHES_CUDNN_VERSION = {
     "13.0": "9",
 }

-ROCM_ARCHES = ["7.0", "7.1"]
+# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
+ROCM_ARCHES = ["6.4", "7.0"]

 XPU_ARCHES = ["xpu"]

@@ -142,48 +137,9 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
 }


-# Used by tools/nightly.py
-PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
-NIGHTLY_SOURCE_MATRIX = {
-    "cpu": dict(
-        name="cpu",
-        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
-        supported_platforms=["Linux", "macOS", "Windows"],
-        accelerator="cpu",
-    )
-}
-CUDA_NIGHTLY_SOURCE_MATRIX = {
-    f"cuda-{major}.{minor}": dict(
-        name=f"cuda-{major}.{minor}",
-        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
-        supported_platforms=["Linux", "Windows"],
-        accelerator="cuda",
-    )
-    for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
-}
-ROCM_NIGHTLY_SOURCE_MATRIX = {
-    f"rocm-{major}.{minor}": dict(
-        name=f"rocm-{major}.{minor}",
-        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
-        supported_platforms=["Linux"],
-        accelerator="rocm",
-    )
-    for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
-}
-XPU_NIGHTLY_SOURCE_MATRIX = {
-    "xpu": dict(
-        name="xpu",
-        index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
-        supported_platforms=["Linux"],
-        accelerator="xpu",
-    )
-}
-NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
-NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
-NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)
-
-
 def get_nccl_wheel_version(arch_version: str) -> str:
+    import re

     requirements = map(
         str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
     )

@@ -191,14 +147,17 @@ def get_nccl_wheel_version(arch_version: str) -> str:


 def read_nccl_pin(arch_version: str) -> str:
-    nccl_pin_path = (
-        REPO_ROOT
-        / ".ci"
-        / "docker"
-        / "ci_commit_pins"
-        / f"nccl-cu{arch_version[:2]}.txt"
+    from pathlib import Path
+
+    nccl_pin_path = os.path.join(
+        Path(__file__).absolute().parents[2],
+        ".ci",
+        "docker",
+        "ci_commit_pins",
+        f"nccl-cu{arch_version[:2]}.txt",
     )
-    return nccl_pin_path.read_text().strip()
+    with open(nccl_pin_path) as f:
+        return f.read().strip()


 def validate_nccl_dep_consistency(arch_version: str) -> None:

@@ -206,8 +165,7 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
     wheel_ver = get_nccl_wheel_version(arch_version)
     if not nccl_release_tag.startswith(f"v{wheel_ver}"):
         raise RuntimeError(
-            f"{arch_version} NCCL release tag version {nccl_release_tag} "
-            f"does not correspond to wheel version {wheel_ver}"
+            f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
         )

@@ -454,14 +412,7 @@ def generate_wheels_matrix(
     return ret


-arch_version = ""
-for arch_version in CUDA_ARCHES:
-    validate_nccl_dep_consistency(arch_version)
-del arch_version
-
-
-if __name__ == "__main__":
-    # Used by tools/nightly.py
-    (SCRIPT_DIR / "nightly_source_matrix.json").write_text(
-        json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
-    )
+validate_nccl_dep_consistency("13.0")
+validate_nccl_dep_consistency("12.9")
+validate_nccl_dep_consistency("12.8")
+validate_nccl_dep_consistency("12.6")
.github/workflows/_rocm-test.yml (4 changes)

@@ -97,8 +97,8 @@ jobs:
       shell: bash
       run: |
         ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
-        if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
-          echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
+        if [[ $ngpu -lt 4 ]]; then
+          echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
           exit 1
         fi

.github/workflows/_xpu-test.yml (29 changes)

@@ -38,10 +38,6 @@ on:
         default: ""
         description: |
           List of tests to include (empty string implies default list)
-      dashboard-tag:
-        required: false
-        type: string
-        default: ""
       disable-monitor:
         description: |
           [Experimental] Disable utilization monitoring for tests.

@@ -62,11 +58,6 @@ on:
         required: false
         type: number
         default: 1
-    secrets:
-      HUGGING_FACE_HUB_TOKEN:
-        required: false
-        description: |
-          HF Auth token to avoid rate limits when downloading models or datasets from hub
 permissions:
   id-token: write
   contents: read

@@ -205,8 +196,6 @@ jobs:
         PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
         PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
         TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
-        DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
-        HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
       timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
       run: |
         # Fetch aws credential from IMDs

@@ -257,8 +246,6 @@ jobs:
           -e PYTORCH_TEST_RERUN_DISABLED_TESTS \
           -e TESTS_TO_INCLUDE \
           -e ZE_AFFINITY_MASK \
-          -e HUGGING_FACE_HUB_TOKEN \
-          -e DASHBOARD_TAG \
           --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
           --ulimit stack=10485760:83886080 \
           --ulimit core=0 \

@@ -344,21 +331,5 @@ jobs:
         if-no-files-found: ignore
         path: ./**/core.[1-9]*

-    - name: Authenticate with AWS
-      uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
-      with:
-        role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
-        # The max duration enforced by the server side
-        role-duration-seconds: 18000
-        aws-region: us-east-1
-
-    - name: Upload the benchmark results
-      uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
-      with:
-        benchmark-results-dir: test/test-reports
-        dry-run: false
-        schema-version: v3
-        github-token: ${{ secrets.GITHUB_TOKEN }}
-
     - name: Teardown XPU
       uses: ./.github/actions/teardown-xpu
.github/workflows/build-almalinux-images.yml (2 changes)

@@ -36,7 +36,7 @@ jobs:
     runs-on: linux.9xlarge.ephemeral
     strategy:
       matrix:
-        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"]
+        tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
     steps:
       - name: Build docker image
         uses: pytorch/pytorch/.github/actions/binary-docker-build@main
.github/workflows/build-libtorch-images.yml (2 changes)

@@ -52,8 +52,8 @@ jobs:
           { tag: "cuda12.9" },
           { tag: "cuda12.8" },
           { tag: "cuda12.6" },
+          { tag: "rocm6.4" },
           { tag: "rocm7.0" },
-          { tag: "rocm7.1" },
           { tag: "cpu" },
         ]
     steps:
.github/workflows/build-magma-rocm-linux.yml (2 changes)

@@ -34,7 +34,7 @@ jobs:
       id-token: write
     strategy:
       matrix:
-        rocm_version: ["71", "70"]
+        rocm_version: ["70", "64"]
     steps:
       - name: Checkout PyTorch
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
.github/workflows/build-manywheel-images.yml (2 changes)

@@ -54,8 +54,8 @@ jobs:
           { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" },
           { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" },
           { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
+          { name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
           { name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
-          { name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" },
           { name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
           { name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
           { name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" },
.github/workflows/build-triton-wheel.yml (9 changes)

@@ -55,7 +55,7 @@ jobs:
         docker-image: ["pytorch/manylinux2_28-builder:cpu"]
         include:
           - device: "rocm"
-            rocm_version: "7.1"
+            rocm_version: "7.0"
            runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
           - device: "cuda"
             rocm_version: ""

@@ -159,7 +159,12 @@ jobs:
             WITH_CLANG_LDD="--with-clang-ldd"
           fi

-          docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
+          if [[ "${BUILD_DEVICE}" == xpu ]]; then
+            docker exec -t "${container_name}" bash -c "dnf install -y gcc-toolset-13-gcc-c++"
+            docker exec -t "${container_name}" bash -c "source /opt/rh/gcc-toolset-13/enable && ${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE"
+          else
+            docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
+          fi

           if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then
             docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl"
.github/workflows/docker-builds.yml (7 changes)

@@ -67,7 +67,6 @@ jobs:
           pytorch-linux-jammy-py3.12-halide,
           pytorch-linux-jammy-xpu-n-1-py3,
           pytorch-linux-jammy-xpu-n-py3,
-          pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
           pytorch-linux-jammy-py3-clang18-asan,
           pytorch-linux-jammy-py3-clang12-onnx,
           pytorch-linux-jammy-linter,

@@ -77,11 +76,9 @@ jobs:
           pytorch-linux-noble-riscv64-py3.12-gcc14
         ]
         include:
-          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
+          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
             runner: linux.arm64.m7g.4xlarge
-          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
-            runner: linux.arm64.m7g.4xlarge
-          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
+          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
             runner: linux.arm64.m7g.4xlarge
     timeout-minutes: 600
     # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358
.github/workflows/docker-release.yml (1 change)

@@ -8,7 +8,6 @@ on:
       - docker.Makefile
       - .github/workflows/docker-release.yml
       - .github/scripts/generate_docker_release_matrix.py
-      - .github/scripts/generate_binary_build_matrix.py
   push:
     branches:
       - nightly
.github/workflows/generated-linux-binary-libtorch-nightly.yml (236 changes, generated)

@@ -384,6 +384,124 @@ jobs:
       github-token: ${{ secrets.GITHUB_TOKEN }}
     uses: ./.github/workflows/_binary-upload.yml

+  libtorch-rocm6_4-shared-with-deps-release-build:
+    if: ${{ github.repository_owner == 'pytorch' }}
+    uses: ./.github/workflows/_binary-build-linux.yml
+    needs: get-label-type
+    with:
+      PYTORCH_ROOT: /pytorch
+      PACKAGE_TYPE: libtorch
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: rocm6.4
+      GPU_ARCH_VERSION: "6.4"
+      GPU_ARCH_TYPE: rocm
+      DOCKER_IMAGE: libtorch-cxx11-builder
+      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
+      LIBTORCH_CONFIG: release
+      LIBTORCH_VARIANT: shared-with-deps
+      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+      timeout-minutes: 300
+      build_name: libtorch-rocm6_4-shared-with-deps-release
+      build_environment: linux-binary-libtorch
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+  libtorch-rocm6_4-shared-with-deps-release-test:  # Testing
+    if: ${{ github.repository_owner == 'pytorch' }}
+    needs:
+      - libtorch-rocm6_4-shared-with-deps-release-build
+      - get-label-type
+    runs-on: linux.rocm.gpu.mi250
+    timeout-minutes: 240
+    env:
+      PYTORCH_ROOT: /pytorch
+      PACKAGE_TYPE: libtorch
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: rocm6.4
+      GPU_ARCH_VERSION: "6.4"
+      GPU_ARCH_TYPE: rocm
+      SKIP_ALL_TESTS: 1
+      DOCKER_IMAGE: libtorch-cxx11-builder
+      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
+      LIBTORCH_CONFIG: release
+      LIBTORCH_VARIANT: shared-with-deps
+    permissions:
+      id-token: write
+      contents: read
+    steps:
+      - name: Setup ROCm
+        uses: ./.github/actions/setup-rocm
+      - uses: actions/download-artifact@v4.1.7
+        name: Download Build Artifacts
+        with:
+          name: libtorch-rocm6_4-shared-with-deps-release
+          path: "${{ runner.temp }}/artifacts/"
+      - name: Checkout PyTorch
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+          submodules: recursive
+          path: pytorch
+          show-progress: false
+      - name: Clean PyTorch checkout
+        run: |
+          # Remove any artifacts from the previous checkouts
+          git clean -fxd
+        working-directory: pytorch
+      - name: ROCm set GPU_FLAG
+        run: |
+          echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
+      - name: configure aws credentials
+        id: aws_creds
+        if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
+        uses: aws-actions/configure-aws-credentials@v4
+        with:
+          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
+          aws-region: us-east-1
+          role-duration-seconds: 18000
+      - name: Calculate docker image
+        id: calculate-docker-image
+        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
+        with:
+          docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
+          docker-image-name: libtorch-cxx11-builder
+          custom-tag-prefix: rocm6.4
+          docker-build-dir: .ci/docker
+          working-directory: pytorch
+      - name: Pull Docker image
+        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
+        with:
+          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
+      - name: Test Pytorch binary
+        uses: ./pytorch/.github/actions/test-pytorch-binary
+        env:
+          DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
+      - name: Teardown ROCm
+        uses: ./.github/actions/teardown-rocm
+  libtorch-rocm6_4-shared-with-deps-release-upload:  # Uploading
+    if: ${{ github.repository_owner == 'pytorch' }}
+    permissions:
+      id-token: write
+      contents: read
+    needs: libtorch-rocm6_4-shared-with-deps-release-test
+    with:
+      PYTORCH_ROOT: /pytorch
+      PACKAGE_TYPE: libtorch
+      # TODO: This is a legacy variable that we eventually want to get rid of in
+      #       favor of GPU_ARCH_VERSION
+      DESIRED_CUDA: rocm6.4
+      GPU_ARCH_VERSION: "6.4"
+      GPU_ARCH_TYPE: rocm
+      DOCKER_IMAGE: libtorch-cxx11-builder
+      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
+      LIBTORCH_CONFIG: release
+      LIBTORCH_VARIANT: shared-with-deps
+      build_name: libtorch-rocm6_4-shared-with-deps-release
+    secrets:
+      github-token: ${{ secrets.GITHUB_TOKEN }}
+    uses: ./.github/workflows/_binary-upload.yml
+
   libtorch-rocm7_0-shared-with-deps-release-build:
     if: ${{ github.repository_owner == 'pytorch' }}
     uses: ./.github/workflows/_binary-build-linux.yml

@@ -501,121 +619,3 @@ jobs:
     secrets:
       github-token: ${{ secrets.GITHUB_TOKEN }}
     uses: ./.github/workflows/_binary-upload.yml
-
-  libtorch-rocm7_1-shared-with-deps-release-build:
-    if: ${{ github.repository_owner == 'pytorch' }}
-    uses: ./.github/workflows/_binary-build-linux.yml
-    needs: get-label-type
-    with:
-      PYTORCH_ROOT: /pytorch
-      PACKAGE_TYPE: libtorch
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: rocm7.1
-      GPU_ARCH_VERSION: "7.1"
-      GPU_ARCH_TYPE: rocm
-      DOCKER_IMAGE: libtorch-cxx11-builder
-      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
-      LIBTORCH_CONFIG: release
-      LIBTORCH_VARIANT: shared-with-deps
-      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      timeout-minutes: 300
-      build_name: libtorch-rocm7_1-shared-with-deps-release
-      build_environment: linux-binary-libtorch
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-  libtorch-rocm7_1-shared-with-deps-release-test:  # Testing
-    if: ${{ github.repository_owner == 'pytorch' }}
-    needs:
-      - libtorch-rocm7_1-shared-with-deps-release-build
-      - get-label-type
-    runs-on: linux.rocm.gpu.mi250
-    timeout-minutes: 240
-    env:
-      PYTORCH_ROOT: /pytorch
-      PACKAGE_TYPE: libtorch
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: rocm7.1
-      GPU_ARCH_VERSION: "7.1"
-      GPU_ARCH_TYPE: rocm
-      SKIP_ALL_TESTS: 1
-      DOCKER_IMAGE: libtorch-cxx11-builder
-      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
-      LIBTORCH_CONFIG: release
-      LIBTORCH_VARIANT: shared-with-deps
-    permissions:
-      id-token: write
-      contents: read
-    steps:
-      - name: Setup ROCm
-        uses: ./.github/actions/setup-rocm
-      - uses: actions/download-artifact@v4.1.7
-        name: Download Build Artifacts
-        with:
-          name: libtorch-rocm7_1-shared-with-deps-release
-          path: "${{ runner.temp }}/artifacts/"
-      - name: Checkout PyTorch
-        uses: actions/checkout@v4
-        with:
-          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
-          submodules: recursive
-          path: pytorch
-          show-progress: false
-      - name: Clean PyTorch checkout
-        run: |
-          # Remove any artifacts from the previous checkouts
-          git clean -fxd
-        working-directory: pytorch
-      - name: ROCm set GPU_FLAG
-        run: |
-          echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
-      - name: configure aws credentials
-        id: aws_creds
-        if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
-        uses: aws-actions/configure-aws-credentials@v4
-        with:
-          role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
-          aws-region: us-east-1
-          role-duration-seconds: 18000
-      - name: Calculate docker image
-        id: calculate-docker-image
-        uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
-        with:
-          docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
-          docker-image-name: libtorch-cxx11-builder
-          custom-tag-prefix: rocm7.1
-          docker-build-dir: .ci/docker
-          working-directory: pytorch
-      - name: Pull Docker image
-        uses: pytorch/test-infra/.github/actions/pull-docker-image@main
-        with:
-          docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
-      - name: Test Pytorch binary
-        uses: ./pytorch/.github/actions/test-pytorch-binary
-        env:
-          DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
-      - name: Teardown ROCm
-        uses: ./.github/actions/teardown-rocm
-  libtorch-rocm7_1-shared-with-deps-release-upload:  # Uploading
-    if: ${{ github.repository_owner == 'pytorch' }}
-    permissions:
-      id-token: write
-      contents: read
-    needs: libtorch-rocm7_1-shared-with-deps-release-test
-    with:
-      PYTORCH_ROOT: /pytorch
-      PACKAGE_TYPE: libtorch
-      # TODO: This is a legacy variable that we eventually want to get rid of in
-      #       favor of GPU_ARCH_VERSION
-      DESIRED_CUDA: rocm7.1
-      GPU_ARCH_VERSION: "7.1"
-      GPU_ARCH_TYPE: rocm
-      DOCKER_IMAGE: libtorch-cxx11-builder
-      DOCKER_IMAGE_TAG_PREFIX: rocm7.1
-      LIBTORCH_CONFIG: release
-      LIBTORCH_VARIANT: shared-with-deps
-      build_name: libtorch-rocm7_1-shared-with-deps-release
-    secrets:
-      github-token: ${{ secrets.GITHUB_TOKEN }}
-    uses: ./.github/workflows/_binary-upload.yml
.github/workflows/generated-linux-binary-manywheel-nightly.yml (1610 changes, generated)

(File diff suppressed because it is too large)
@@ -72,7 +72,7 @@ jobs:
       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
       runner: linux.arm64.m7g.4xlarge
       build-environment: linux-jammy-aarch64-py3.10
-      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
+      docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
       test-matrix: |
         { include: [
           { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },
.github/workflows/inductor-perf-test-nightly-xpu.yml (148 changes)

@@ -1,148 +0,0 @@
-name: inductor-perf-nightly-xpu
-
-on:
-  push:
-    tags:
-      - ciflow/inductor-perf-test-nightly-xpu/*
-  schedule:
-    - cron: 30 17 * * *
-  workflow_dispatch:
-    inputs:
-      training:
-        description: Run training (on by default)?
-        required: false
-        type: boolean
-        default: true
-      inference:
-        description: Run inference (on by default)?
-        required: false
-        type: boolean
-        default: true
-      default:
-        description: Run inductor_default?
-        required: false
-        type: boolean
-        default: false
-      dynamic:
-        description: Run inductor_dynamic_shapes?
-        required: false
-        type: boolean
-        default: false
-      cppwrapper:
-        description: Run inductor_cpp_wrapper?
-        required: false
-        type: boolean
-        default: false
-      cudagraphs:
-        description: Run inductor_cudagraphs?
-        required: false
-        type: boolean
-        default: false
-      freezing_cudagraphs:
-        description: Run inductor_cudagraphs with freezing for inference?
-        required: false
-        type: boolean
-        default: false
-      aotinductor:
-        description: Run aot_inductor for inference?
-        required: false
-        type: boolean
-        default: false
-      maxautotune:
-        description: Run inductor_max_autotune?
-        required: false
-        type: boolean
-        default: false
-      benchmark_configs:
-        description: The list of configs used the benchmark
-        required: false
-        type: string
-        default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
-  cancel-in-progress: true
-
-permissions: read-all
-
-jobs:
-  get-label-type:
-    name: get-label-type
-    uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
-    if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
-    with:
-      triggering_actor: ${{ github.triggering_actor }}
-      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
-      curr_branch: ${{ github.head_ref || github.ref_name }}
-      curr_ref_type: ${{ github.ref_type }}
-      opt_out_experiments: lf
-
-  xpu-n-py3_10-inductor-benchmark-build:
-    name: xpu-n-py3.10-inductor-benchmark
-    uses: ./.github/workflows/_linux-build.yml
-    needs: get-label-type
-    with:
-      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-      build-environment: linux-jammy-xpu-n-py3.10
-      docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
-      runner: linux.c7i.12xlarge
-      test-matrix: |
-        { include: [
-          { config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
-          { config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
-          { config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
-          { config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
-          { config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
-          { config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
-        ]}
-    secrets: inherit
-
-  xpu-n-py3_10-inductor-benchmark-test-nightly:
-    permissions:
-      id-token: write
-      contents: read
-    if: github.event_name != 'workflow_dispatch'
-    name: xpu-n-py3.10-inductor-benchmark
-    uses: ./.github/workflows/_xpu-test.yml
-    needs: xpu-n-py3_10-inductor-benchmark-build
-    with:
-      build-environment: linux-jammy-xpu-n-py3.10
-      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
-      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
-      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
-      timeout-minutes: 720
-      # Disable monitor in perf tests for more investigation
-      disable-monitor: true
-      monitor-log-interval: 10
-      monitor-data-collect-interval: 2
-    secrets: inherit
-
-  xpu-n-py3_10-inductor-benchmark-test:
-    permissions:
-      id-token: write
-      contents: read
-    if: github.event_name == 'workflow_dispatch'
-    name: xpu-n-py3.10-inductor-test
-    uses: ./.github/workflows/_xpu-test.yml
-    needs: xpu-n-py3_10-inductor-benchmark-build
-    with:
-      build-environment: linux-jammy-xpu-n-py3.10
-      dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
-      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
-      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
-      timeout-minutes: 720
-      disable-monitor: false
-      monitor-log-interval: 15
-      monitor-data-collect-interval: 4
-    secrets: inherit
.github/workflows/inductor-rocm.yml (3 changes)

@@ -1,10 +1,9 @@
 name: inductor-rocm

 on:
-  schedule:
-    - cron: 0 */3 * * *
   push:
     branches:
+      - main
       - release/*
     tags:
       - ciflow/inductor-rocm/*
8
.github/workflows/inductor-unittest.yml
vendored
8
.github/workflows/inductor-unittest.yml
vendored
@ -115,10 +115,10 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||||
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" },
|
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
|
||||||
]}
|
]}
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|
||||||
|
|||||||
14
.github/workflows/inductor.yml
vendored
14
.github/workflows/inductor.yml
vendored
@ -84,13 +84,13 @@ jobs:
|
|||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" },
|
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
|
||||||
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
|
||||||
]}
|
]}
|
||||||
build-additional-packages: "vision audio torchao"
|
build-additional-packages: "vision audio torchao"
|
||||||
|
|||||||
15
.github/workflows/lint.yml
vendored
15
.github/workflows/lint.yml
vendored
@ -76,12 +76,11 @@ jobs:
|
|||||||
|
|
||||||
# NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
|
# NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
|
||||||
# fails to find types when it should
|
# fails to find types when it should
|
||||||
# NOTE: We should be able to disable this and consolidate with Pyrefly
|
lintrunner-mypy:
|
||||||
lintrunner-pyrefly:
|
|
||||||
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
||||||
name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
|
name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
|
||||||
needs: [get-label-type, get-changed-files]
|
needs: [get-label-type, get-changed-files]
|
||||||
# Only run if there are changed files relevant to pyrefly
|
# Only run if there are changed files relevant to mypy
|
||||||
if: |
|
if: |
|
||||||
github.repository_owner == 'pytorch' && (
|
github.repository_owner == 'pytorch' && (
|
||||||
needs.get-changed-files.outputs.changed-files == '*' ||
|
needs.get-changed-files.outputs.changed-files == '*' ||
|
||||||
@ -99,8 +98,8 @@ jobs:
|
|||||||
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
|
||||||
script: |
|
script: |
|
||||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||||
echo "Running pyrefly"
|
echo "Running mypy"
|
||||||
ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh
|
ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
|
||||||
|
|
||||||
lintrunner-noclang:
|
lintrunner-noclang:
|
||||||
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
|
||||||
@ -119,9 +118,9 @@ jobs:
|
|||||||
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
|
||||||
echo "Running all other linters"
|
echo "Running all other linters"
|
||||||
if [ "$CHANGED_FILES" = '*' ]; then
|
if [ "$CHANGED_FILES" = '*' ]; then
|
||||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh
|
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
|
||||||
else
|
else
|
||||||
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
|
||||||
fi
|
fi
|
||||||
|
|
||||||
quick-checks:
|
quick-checks:
|
||||||
|
|||||||
2
.github/workflows/linux-aarch64.yml
vendored
2
.github/workflows/linux-aarch64.yml
vendored
@ -33,7 +33,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||||
build-environment: linux-jammy-aarch64-py3.10
|
build-environment: linux-jammy-aarch64-py3.10
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||||
runner: linux.arm64.m7g.4xlarge
|
runner: linux.arm64.m7g.4xlarge
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
|
|||||||
2
.github/workflows/nightly.yml
vendored
2
.github/workflows/nightly.yml
vendored
@ -41,7 +41,7 @@ jobs:
|
|||||||
uses: ./.github/workflows/_linux-build.yml
|
uses: ./.github/workflows/_linux-build.yml
|
||||||
needs: get-label-type
|
needs: get-label-type
|
||||||
with:
|
with:
|
||||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
|
||||||
build-environment: linux-jammy-py3.10-gcc11
|
build-environment: linux-jammy-py3.10-gcc11
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
|
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
|
||||||
secrets: inherit
|
secrets: inherit
|
||||||
|
|||||||
2
.github/workflows/operator_benchmark.yml
vendored
2
.github/workflows/operator_benchmark.yml
vendored
@ -60,7 +60,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
build-environment: linux-jammy-aarch64-py3.10
|
build-environment: linux-jammy-aarch64-py3.10
|
||||||
runner: linux.arm64.m7g.4xlarge
|
runner: linux.arm64.m7g.4xlarge
|
||||||
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
|
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
|
||||||
test-matrix: |
|
test-matrix: |
|
||||||
{ include: [
|
{ include: [
|
||||||
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
|
||||||
.github/workflows/pull.yml (8 changes, vendored)
@@ -66,10 +66,10 @@ jobs:
           { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
-          { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
+          { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
-          { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
+          { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
         ]}
       secrets: inherit

@@ -167,8 +167,8 @@ jobs:
       docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
-          { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" },
+          { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
         ]}
       secrets: inherit
.github/workflows/rocm.yml (3 changes, vendored)
@@ -3,14 +3,13 @@ name: rocm
 on:
   push:
     branches:
+      - main
       - release/*
     tags:
       - ciflow/rocm/*
   workflow_dispatch:
   schedule:
     - cron: 29 8 * * *  # about 1:29am PDT
-    - cron: 0 */3 * * *

 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
.github/workflows/trunk.yml (3 changes, vendored)
@@ -204,7 +204,6 @@ jobs:
         { include: [
           { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
           { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
-          { config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
         ]}
       secrets: inherit

@@ -222,7 +221,7 @@ jobs:
       build-environment: linux-jammy-rocm-py3.10
       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
-      tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
+      tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
       secrets: inherit

   inductor-build:
.gitignore (3 changes, vendored)
@@ -127,7 +127,6 @@ torch/test/
 torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
 torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
 torch/version.py
-torch/_inductor/kernel/vendored_templates/*
 minifier_launcher.py
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
 aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@@ -144,7 +143,6 @@ scripts/release_notes/*.json
 sccache-stats*.json
 lint.json
 merge_record.json
-.github/scripts/nightly_source_matrix.json

 # These files get copied over on invoking setup.py
 torchgen/packaged/*
@@ -399,4 +397,3 @@ CLAUDE.local.md
 /test_*.py
 /debug_*.py
 CLAUDE_CONTEXT/
-/.claude/settings.local.json
.lintrunner.toml
@@ -121,6 +121,94 @@ command = [
 ]
 is_formatter = true

+[[linter]]
+code = 'MYPY'
+include_patterns = [
+    'setup.py',
+    'functorch/dim/**/*.py',
+    'torch/**/*.py',
+    'torch/**/*.pyi',
+    'caffe2/**/*.py',
+    'caffe2/**/*.pyi',
+    'test/test_bundled_images.py',
+    'test/test_bundled_inputs.py',
+    'test/test_complex.py',
+    'test/test_datapipe.py',
+    'test/test_futures.py',
+    'test/test_numpy_interop.py',
+    'test/test_torch.py',
+    'test/test_type_hints.py',
+    'test/test_type_info.py',
+    'test/test_utils.py',
+]
+exclude_patterns = [
+    '**/fb/**',
+]
+command = [
+    'python3',
+    'tools/linter/adapters/mypy_linter.py',
+    '--config=mypy.ini',
+    '--',
+    '@{{PATHSFILE}}'
+]
+init_command = [
+    'python3',
+    'tools/linter/adapters/pip_init.py',
+    '--dry-run={{DRYRUN}}',
+    'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
+    'numpy==2.1.0 ; python_version >= "3.12"',
+    'expecttest==0.3.0',
+    'mypy==1.16.0',
+    'sympy==1.13.3',
+    'types-requests==2.27.25',
+    'types-pyyaml==6.0.2',
+    'types-tabulate==0.8.8',
+    'types-protobuf==5.29.1.20250403',
+    'types-setuptools==79.0.0.20250422',
+    'types-jinja2==2.11.9',
+    'types-colorama==0.4.6',
+    'filelock==3.18.0',
+    'junitparser==2.1.1',
+    'rich==14.1.0',
+    'pyyaml==6.0.2',
+    'optree==0.13.0',
+    'dataclasses-json==0.6.7',
+    'pandas==2.2.3',
+]
+
+[[linter]]
+code = 'MYPYSTRICT'
+include_patterns = [
+    '.github/**/*.py',
+    'benchmarks/instruction_counts/**/*.py',
+    'tools/**/*.py',
+    'torchgen/**/*.py',
+    'torch/utils/_pytree.py',
+    'torch/utils/_cxx_pytree.py',
+    'torch/utils/benchmark/utils/common.py',
+    'torch/utils/benchmark/utils/timer.py',
+    'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
+]
+exclude_patterns = [
+    # (linbinyu) copied from internal repo
+    '**/fb/**',
+    'tools/code_analyzer/gen_operators_yaml.py',
+    'tools/dynamo/verify_dynamo.py',
+    'tools/gen_vulkan_spv.py',
+    'tools/test/gen_operators_yaml_test.py',
+    'tools/test/gen_oplist_test.py',
+    'tools/test/test_selective_build.py',
+    'tools/experimental/torchfuzz/**',
+]
+command = [
+    'python3',
+    'tools/linter/adapters/mypy_linter.py',
+    '--config=mypy-strict.ini',
+    '--code=MYPYSTRICT',
+    '--',
+    '@{{PATHSFILE}}'
+]
+
 [[linter]]
 code = 'PYREFLY'
@@ -142,9 +230,7 @@ init_command = [
     'python3',
     'tools/linter/adapters/pip_init.py',
     '--dry-run={{DRYRUN}}',
-    'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
-    'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
-    'numpy==2.3.4 ; python_version >= "3.14"',
+    'numpy==2.1.0 ; python_version >= "3.12"',
     'expecttest==0.3.0',
     'pyrefly==0.36.2',
     'sympy==1.13.3',
@@ -212,6 +298,7 @@ exclude_patterns = [
     '**/*pb.h',
     '**/*inl.h',
     'aten/src/ATen/cpu/FlushDenormal.cpp',
+    'aten/src/ATen/cpu/Utils.cpp',
     'aten/src/ATen/cpu/vml.h',
     'aten/src/ATen/CPUFixedAllocator.h',
     'aten/src/ATen/Parallel*.h',
@@ -230,6 +317,8 @@ exclude_patterns = [
     'c10/util/win32-headers.h',
     'c10/test/**/*.h',
     'third_party/**/*',
+    'torch/csrc/api/include/torch/nn/modules/common.h',
+    'torch/csrc/api/include/torch/linalg.h',
     'torch/csrc/autograd/generated/**',
     'torch/csrc/distributed/**/*.cu',
     'torch/csrc/distributed/c10d/WinSockUtils.hpp',
@@ -241,6 +330,7 @@ exclude_patterns = [
     'torch/csrc/utils/generated_serialization_types.h',
     'torch/csrc/utils/pythoncapi_compat.h',
     'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
+    'aten/src/ATen/ExpandBase.h',
 ]
 init_command = [
     'python3',
CMakeLists.txt
@@ -234,17 +234,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
 option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
 option(USE_LSAN "Use Leak Sanitizer" OFF)
 option(USE_TSAN "Use Thread Sanitizer" OFF)
-
-# Track whether USE_CUDA was explicitly set by the user (before option() is called)
-# If USE_CUDA is already defined in cache, it means user explicitly set it
-if(DEFINED CACHE{USE_CUDA})
-  set(_USE_CUDA_EXPLICITLY_SET TRUE)
-else()
-  set(_USE_CUDA_EXPLICITLY_SET FALSE)
-endif()
-
 option(USE_CUDA "Use CUDA" ON)
-
 option(USE_XPU "Use XPU" ON)
 cmake_dependent_option(
     BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
CONTRIBUTING.md
@@ -11,6 +11,7 @@ aspects of contributing to PyTorch.
 <!-- toc -->

 - [Developing PyTorch](#developing-pytorch)
+  - [Setup the development environment](#setup-the-development-environment)
   - [Tips and Debugging](#tips-and-debugging)
 - [Nightly Checkout & Pull](#nightly-checkout--pull)
 - [Codebase structure](#codebase-structure)
@@ -18,7 +19,7 @@ aspects of contributing to PyTorch.
   - [Python Unit Testing](#python-unit-testing)
   - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
   - [Local linting](#local-linting)
-    - [Running `pyrefly`](#running-pyrefly)
+    - [Running `mypy`](#running-mypy)
   - [C++ Unit Testing](#c-unit-testing)
   - [Run Specific CI Jobs](#run-specific-ci-jobs)
 - [Merging your Change](#merging-your-change)
@@ -66,6 +67,23 @@ aspects of contributing to PyTorch.

 Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.

+### Setup the development environment
+
+First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.
+
+Then clone the PyTorch project and setup the development environment:
+
+```bash
+git clone git@github.com:<USERNAME>/pytorch.git
+cd pytorch
+git remote add upstream git@github.com:pytorch/pytorch.git
+
+make setup-env
+# Or run `make setup-env-cuda` for pre-built CUDA binaries
+# Or run `make setup-env-rocm` for pre-built ROCm binaries
+source venv/bin/activate  # or `. .\venv\Scripts\activate` on Windows
+```
+
 ### Tips and Debugging

 * If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
@@ -281,7 +299,7 @@ dependencies as well as the nightly binaries into the repo directory.
 **Prerequisites**:
 The following packages should be installed with `pip`:
 - `expecttest` and `hypothesis` - required to run tests
-- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
+- `mypy` - recommended for linting
 - `pytest` - recommended to run tests more selectively
 Running
 ```
@@ -350,32 +368,15 @@ make lint

 Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)

-#### Running `pyrefly`
+#### Running `mypy`

-[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
-
-PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
-
-**Getting Started with Pyrefly:**
-
-To run type checking on the PyTorch codebase:
-```bash
-pyrefly check
-```
-
-For more detailed error information with summaries:
-```bash
-pyrefly check --summarize-errors
-```
-
-**Learn More:**
-- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
-- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
-- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
+`mypy` is an optional static type checker for Python. We have multiple `mypy`
+configs for the PyTorch codebase that are automatically validated against whenever the linter is run.

 See [Guide for adding type annotations to
 PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
-for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
+for more information on how to set up `mypy` and tackle type annotation
+tasks.

 ### C++ Unit Testing
SECURITY.md (20 changes)
@@ -1,7 +1,7 @@
 # Security Policy

 - [**Reporting a Vulnerability**](#reporting-a-vulnerability)
-- [**Using PyTorch Securely**](#using-pytorch-securely)
+- [**Using Pytorch Securely**](#using-pytorch-securely)
   - [Untrusted models](#untrusted-models)
   - [TorchScript models](#torchscript-models)
   - [Untrusted inputs](#untrusted-inputs)
@@ -10,28 +10,28 @@
 - [**CI/CD security principles**](#cicd-security-principles)
 ## Reporting Security Issues

-Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
+Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.

 However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.

 Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new

-All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
+All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.

 Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:

 https://www.facebook.com/whitehat


-## Using PyTorch Securely
-**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
+## Using Pytorch Securely
+**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).

 ### Untrusted models
 Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].

 **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).

-**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
+**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.

 Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.

@@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de

 ### TorchScript models

-TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
+TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.

 ### Untrusted inputs during training and prediction

@@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some

 ### Data privacy

-**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
-- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
-- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
+**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
+- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
+- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).

 ### Using distributed features
caffe2/CMakeLists.txt
@@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
   if(USE_CUDA)
     # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
     # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
-    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*")
+    set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
     file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
       "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")
aten/src/ATen/CPUGeneratorImpl.cpp
@@ -181,7 +181,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
   static const size_t size = sizeof(CPUGeneratorImplState);
   static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");

-  auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
+  auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
   auto rng_state = state_tensor.data_ptr();

   // accumulate generator data to be copied into byte tensor
aten/src/ATen/Context.cpp
@@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
 #endif
 namespace at {

+namespace {
+
 /*
   These const variables defined the fp32 precisions for different backend
   We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@@ -39,6 +41,16 @@ namespace at {
     ->rnn
 */

+C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
+  TORCH_WARN_ONCE(
+    "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
+    "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
+    "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
+    "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
+  );
+}
+} // namespace
+
 Float32Backend str2backend(const std::string& name) {
   if (name == "generic")
     return Float32Backend::GENERIC;
@@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
   } else {
     return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
   }
+  warn_deprecated_fp32_precision_api();
   return allow_tf32_cudnn;
 }

@@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
   setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
   setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
   allow_tf32_cudnn = b;
+  warn_deprecated_fp32_precision_api();
 }

 void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@@ -209,7 +223,7 @@ void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
       "setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ",
       at::num_sdp_backends, " unique backends specified in priority order.");
   for (uint32_t i = 0; i < order.size(); i++) {
-    sdp_priority_order[i] = static_cast<at::SDPBackend>(order[i]);
+    sdp_priority_order[i] = (at::SDPBackend) order[i];
   }
 }

@@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
       "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
       "We suggest only using the new API to set the TF32 flag. See also: ",
       "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
+  warn_deprecated_fp32_precision_api();
   return allow_tf32_new;
 }

@@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
       "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
       "We suggest only using the new API for matmul precision. See also: ",
       "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
+  warn_deprecated_fp32_precision_api();
   return float32_matmul_precision;
 }

@@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)

 void Context::setFloat32MatmulPrecision(const std::string &s) {
   auto match = [this](const std::string & s_) {
+    warn_deprecated_fp32_precision_api();
     // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
     if (s_ == "highest") {
       float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
@@ -808,14 +825,6 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
   display_vmap_fallback_warnings_ = enabled;
 }

-bool Context::warnOnAccumulateGradStreamMismatch() const {
-  return warn_on_accumulate_grad_stream_mismatch_;
-}
-
-void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
-  warn_on_accumulate_grad_stream_mismatch_ = enabled;
-}
-
 bool Context::isDefaultMobileCPUAllocatorSet() {
   return prev_allocator_ptr_ != nullptr;
 }
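These Context.cpp hunks re-add warn_deprecated_fp32_precision_api() and call it from the legacy TF32 getters and setters. As a rough sketch of the two call styles the warning contrasts — using only the member signatures visible in the hunks above, so treat the exact call sites as an assumption rather than a verbatim API reference:

```cpp
#include <ATen/Context.h>

// Sketch: legacy vs. new-style TF32 control, per the warning text above.
// setAllowTF32CuDNN / setFloat32Precision are the signatures shown in this
// diff; whether both are reachable in your build is an assumption.
void configure_tf32() {
  at::Context& ctx = at::globalContext();

  // Legacy style: one global boolean, now routed through the new settings
  // (and about to warn, via warn_deprecated_fp32_precision_api).
  ctx.setAllowTF32CuDNN(true);

  // New style: precision is scoped to a backend and an op.
  ctx.setFloat32Precision(
      at::Float32Backend::CUDA, at::Float32Op::CONV, at::Float32Precision::TF32);
  ctx.setFloat32Precision(
      at::Float32Backend::CUDA, at::Float32Op::RNN, at::Float32Precision::NONE);
}
```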
aten/src/ATen/Context.h
@@ -174,12 +174,6 @@ class TORCH_API Context {
   static long versionCuDNN() {
     return detail::getCUDAHooks().versionCuDNN();
   }
-  static long versionRuntimeCuDNN() {
-    return detail::getCUDAHooks().versionRuntimeCuDNN();
-  }
-  static long versionCuDNNFrontend() {
-    return detail::getCUDAHooks().versionCuDNNFrontend();
-  }
   static bool hasCuSOLVER() {
     return detail::getCUDAHooks().hasCuSOLVER();
   }
@@ -410,9 +404,6 @@ class TORCH_API Context {
   void setDisplayVmapFallbackWarnings(bool enabled);
   bool areVmapFallbackWarningsEnabled() const;

-  void setWarnOnAccumulateGradStreamMismatch(bool enabled);
-  bool warnOnAccumulateGradStreamMismatch() const;
-
   bool isDefaultMobileCPUAllocatorSet();
   void setDefaultMobileCPUAllocator();
   void unsetDefaultMobileCPUAllocator();
@@ -503,7 +494,6 @@ class TORCH_API Context {
   bool release_original_weights = false;
 #endif
   bool display_vmap_fallback_warnings_ = false;
-  bool warn_on_accumulate_grad_stream_mismatch_ = true;
   std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
   bool enable_sparse_tensor_invariant_checks = false;
   bool allow_fp16_reduction_cpu = false;
aten/src/ATen/Dispatch.h
@@ -6,7 +6,6 @@
 #include <c10/util/Half.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/complex.h>
-#include <torch/headeronly/core/Dispatch.h>

 #ifdef __CUDACC__
 #include <cuda.h> // For CUDA_VERSION
@@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
     }                                                 \
   } while (0)

 #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...)                 \
-  THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL(                                      \
-      AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
+  case enum_type: {                                                           \
+    AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type);                              \
+    using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
+    return __VA_ARGS__();                                                     \
+  }

 #define AT_DISPATCH_CASE(enum_type, ...) \
   AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
     return __VA_ARGS__();                       \
   }

+namespace detail {
+
+inline at::ScalarType scalar_type(at::ScalarType s) {
+  return s;
+}
+
+} // namespace detail
+
 // The AT_DISPATCH_* family of macros provides the ability to
 // conveniently generate specializations of a kernel over all of the
 // dtypes we care about in PyTorch. We call it "dispatch" because
@@ -180,13 +190,25 @@ TORCH_API void record_kernel_function_dtype(std::string name);
 // but we're just being safe (and it doesn't hurt.) Note we must
 // use it to shut up warnings about unused store.

 #define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
-  THO_DISPATCH_SWITCH_TMPL(                                                 \
-      RECORD_KERNEL_FUNCTION_DTYPE,                                         \
-      TORCH_CHECK_NOT_IMPLEMENTED,                                          \
-      TYPE,                                                                 \
-      NAME,                                                                 \
-      __VA_ARGS__)
+  [&] {                                                                     \
+    const auto& the_type = TYPE;                                            \
+    constexpr const char* at_dispatch_name = NAME;                          \
+    /* don't use TYPE again in case it is an expensive or side-effect op */ \
+    at::ScalarType _st = ::detail::scalar_type(the_type);                   \
+    RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
+    switch (_st) {                                                          \
+      __VA_ARGS__                                                           \
+      default:                                                              \
+        TORCH_CHECK_NOT_IMPLEMENTED(                                        \
+            false,                                                          \
+            '"',                                                            \
+            at_dispatch_name,                                               \
+            "\" not implemented for '",                                     \
+            toString(_st),                                                  \
+            "'");                                                           \
+    }                                                                       \
+  }()

 #define AT_DISPATCH_CASE_FLOATING_TYPES(...)               \
   AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__)    \
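For orientation, these macros are consumed roughly as follows. The helper below is hypothetical, but the AT_DISPATCH_SWITCH / AT_DISPATCH_CASE_FLOATING_TYPES pairing matches the definitions in the hunk, with `scalar_t` bound inside each generated `case` by AT_PRIVATE_CASE_TYPE_USING_HINT:

```cpp
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

// Hypothetical helper: fill a CPU tensor with ones, specialized per dtype.
// The switch dispatches on the runtime dtype; each case instantiates the
// lambda with scalar_t set to the matching C++ type.
void fill_with_one(at::Tensor& t) {
  AT_DISPATCH_SWITCH(t.scalar_type(), "fill_with_one",
      AT_DISPATCH_CASE_FLOATING_TYPES([&] {
        scalar_t* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(1);
        }
      }));
}
```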
aten/src/ATen/Dispatch_v2.h
@@ -1,8 +1,3 @@
-#pragma once
-
-#include <torch/headeronly/core/Dispatch_v2.h>
-
-// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
 #include <ATen/Dispatch.h>

 // This is a new implementation of the AT_DISPATCH macro family from
@@ -79,19 +74,41 @@
 // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
 // relied on GPT4 to help me get it right.

+// Public API macros
+
 // See documentation above
 #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
-  THO_DISPATCH_V2_TMPL(                       \
-      AT_DISPATCH_SWITCH,                     \
-      AT_DISPATCH_CASE,                       \
-      TYPE,                                   \
-      NAME,                                   \
-      AT_WRAP(BODY),                          \
-      __VA_ARGS__)
+  AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
+
+// This macro lets you pass an arbitrary expression that may contain internal
+// commas to another macro without having the commas causing the expression
+// to be interpreted as being multiple arguments
+#define AT_WRAP(...) __VA_ARGS__
+
+#define AT_FLOAT8_TYPES                                           \
+  c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn,  \
+      c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
+
+#define AT_INTEGRAL_TYPES \
+  c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
+#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
+#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
+#define AT_INTEGRAL_TYPES_V2 \
+  AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
+#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
+#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
+// NB: not *actually* all types
+#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
+#define AT_ALL_TYPES_AND_COMPLEX \
+  AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
+
-// Unused helper macros, kept for BC:
+// Helper macros
 #define AT_AP_VAR(N, T, ...) \
   AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
+#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
+#define AT_CONCAT_AUX(a, b) a##b
+#define AT_EXPAND(X) X

 // Ensure we never have too many scalar types for the expansion here to
 // support. To bump this, you must regenerate the macros below.
@@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);

 num_args = 60

+nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
+args = ', '.join(f'_{i}' for i in range(1, num_args+1))
+
+print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
+print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
+
 for i in range(1, num_args+1):
     args = ', '.join(f'_{i}' for i in range(1, i+1))
     cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@@ -112,6 +135,8 @@ for i in range(1, num_args+1):
 // Begin generated code
 // clang-format off

+#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
+#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
 #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
 #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
 #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
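A hypothetical caller of the v2 macro, assuming only what the hunk above defines (AT_DISPATCH_V2, AT_WRAP, AT_EXPAND, and the type-list macros): the body is wrapped so its internal commas survive expansion, and the trailing arguments name the dtypes to instantiate:

```cpp
#include <ATen/Dispatch_v2.h>
#include <ATen/core/Tensor.h>

// Hypothetical in-place scale kernel. AT_EXPAND splices a type-list macro
// into the variadic dtype arguments; extra single dtypes can be appended.
void scale_(at::Tensor& t, double factor) {
  AT_DISPATCH_V2(t.scalar_type(), "scale_",
      AT_WRAP([&] {
        scalar_t* data = t.data_ptr<scalar_t>();
        for (int64_t i = 0; i < t.numel(); ++i) {
          data[i] = static_cast<scalar_t>(data[i] * factor);
        }
      }),
      AT_EXPAND(AT_FLOATING_TYPES), c10::kHalf);
}
```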
aten/src/ATen/MapAllocator.cpp
@@ -252,13 +252,13 @@ MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd,
   if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) {
     if (flags_ & ALLOCATOR_MAPPED_SHARED) {
       // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
-      if ((fd = open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
+      if ((fd = open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
         TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
       }
     } else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) {
 #ifdef HAVE_SHM_OPEN
       // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
-      if((fd = shm_open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) {
+      if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
         TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
       }
 #else
@@ -503,7 +503,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *fi

 void RefcountedMapAllocator::initializeAlloc() {
   TORCH_CHECK(base_ptr_, "base_ptr_ is null");
-  MapInfo *map_info = static_cast<MapInfo*>(base_ptr_);
+  MapInfo *map_info = (MapInfo*)base_ptr_;

 #ifdef _WIN32
   ReleaseContext* r_ctx = new ReleaseContext;
@@ -539,7 +539,7 @@ void RefcountedMapAllocator::close() {
   }
 #else /* _WIN32 */

-  MapInfo *info = static_cast<MapInfo*>(data);
+  MapInfo *info = (MapInfo*)(data);
   if (--info->refcount == 0) {
 #ifdef HAVE_SHM_UNLINK
     if (shm_unlink(filename_.c_str()) == -1) {
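This hunk, like the CPUGeneratorImpl.cpp and TensorIterator.cpp ones nearby, trades `static_cast` spellings for C-style casts. The two are equivalent for these particular pointer and integer conversions, but `static_cast` is generally preferred because a C-style cast silently falls back to `reinterpret_cast`/`const_cast` semantics when the checked conversion does not apply. A minimal illustration (hypothetical function and values):

```cpp
#include <cstddef>
#include <cstdint>

void cast_styles(void* data, std::size_t size) {
  // C-style cast: picks whichever conversion makes the expression compile,
  // up to and including reinterpret_cast semantics.
  char* a = (char*)data;
  int64_t n1 = (int64_t)size;

  // static_cast: only well-defined conversions; void* -> char* is allowed
  // here, but e.g. casting away const would fail to compile.
  char* b = static_cast<char*>(data);
  int64_t n2 = static_cast<int64_t>(size);

  (void)a; (void)b; (void)n1; (void)n2;  // silence unused-variable warnings
}
```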
aten/src/ATen/TensorIterator.cpp
@@ -862,7 +862,7 @@ void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
   shape_[dim] = size;
   view_offsets_[dim] += start;
   for (auto& op : operands_) {
-    op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start;
+    op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
   }
   if (size == 1 && !is_reduction_) {
     coalesce_dimensions();
@@ -873,7 +873,7 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic
   TORCH_INTERNAL_ASSERT(start_dim <= ndim());
   for (const auto i : c10::irange(start_dim, ndim())) {
     for (auto& op : operands_) {
-      op.data = (static_cast<char*>(op.data)) + op.stride_bytes[i] * indices[i - start_dim];
+      op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
     }
     shape_[i] = 1;
   }
aten/src/ATen/TensorIteratorInternal.h
@@ -41,7 +41,7 @@ inline void serial_for_each(
     IntArrayRef strides,
     char** base_ptrs,
     size_t ntensors,
-    TensorIteratorBase::loop2d_t loop,
+    typename TensorIteratorBase::loop2d_t loop,
     Range range) {
   const auto ndim = shape.size();
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
aten/src/ATen/VmapModeRegistrations.cpp
@@ -72,16 +72,10 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
   m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);

   m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("rand_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
   m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("randn_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);

   m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("randint_like.Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
   m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("randint_like.generator", unsupportedRandomOp<const Tensor&, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("randint_like.Tensor_generator", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
-  m.impl("randint_like.low_generator_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);

   m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
   m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);
aten/src/ATen/core/IListRef.h
@@ -190,14 +190,12 @@ class IListRef;
  * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
  */
 #define TORCH_ILISTREF_UNWRAP(TAG, BODY)                           \
-  C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")      \
   switch (TAG) {                                                   \
     TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY)   \
     break;                                                         \
     default:                                                       \
       TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag.");       \
-  }                                                                \
-  C10_DIAGNOSTIC_POP()
+  }

 enum class IListRefTag {
 #define DEFINE_TAG(tag, ...) tag,
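This hunk (and the enum_type.h hunk further down) removes the same warning-guard pattern. As a sketch of what the guard does around a deliberately non-exhaustive switch — the enum, function, and values here are hypothetical, only the two C10 macros are taken from the diff:

```cpp
#include <c10/macros/Macros.h>

enum class Color { Red, Green, Blue };

// Only Red is special-cased; the default covers the rest. The push macro
// silences -Wswitch-enum (which fires even when a default is present) for
// this switch, and the pop restores the previous warning state.
bool is_red(Color c) {
  C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
  switch (c) {
    case Color::Red:
      return true;
    default:
      return false;
  }
  C10_DIAGNOSTIC_POP()
}
```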
aten/src/ATen/core/TransformationHelper.h
@@ -56,7 +56,7 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
  * in this overloaded version
  */
 template <typename T, typename V>
-C10_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T>uniform_int(V val) {
+C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
   if constexpr (std::is_same_v<T, bool>) {
     return static_cast<bool>(val & 1);
   } else if constexpr (std::is_same_v<T, int64_t>) {
|
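The hunk above constrains `uniform_int` to non-floating-point result types via `enable_if_t` and special-cases `bool` by keeping only the lowest random bit. A minimal standalone sketch of that dispatch (simplified signatures, not the actual ATen implementation — wider integral types would still need range reduction elsewhere):

```cpp
#include <cstdint>
#include <type_traits>

// Map raw random bits to T; bool keeps the lowest bit only.
template <typename T, typename V>
inline std::enable_if_t<!std::is_floating_point_v<T>, T> uniform_int(V val) {
  if constexpr (std::is_same_v<T, bool>) {
    return static_cast<bool>(val & 1);
  } else {
    return static_cast<T>(val);  // range reduction omitted in this sketch
  }
}

int main() {
  uint64_t bits = 0xDEADBEEFULL;
  bool b = uniform_int<bool>(bits);       // lowest bit -> true
  int64_t i = uniform_int<int64_t>(bits);
  return (b && i != 0) ? 0 : 1;
}
```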
@@ -114,25 +114,25 @@ inline typename remove_symint<T>::type unpackSymInt(T x) {
 }
 
 template <>
-inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
+inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
   return x.guard_int(__FILE__, __LINE__);
 }
 
 template <>
-inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
+inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
     c10::SymIntArrayRef x) {
   return C10_AS_INTARRAYREF_SLOW(x);
 }
 
 template <>
-inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
+inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
    std::optional<c10::SymInt> x) {
   return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
                        : std::nullopt;
 }
 
 template <>
-inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
+inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
     at::OptionalSymIntArrayRef x) {
   return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
                        : std::nullopt;
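The specializations above strip symbolic-int wrappers: `remove_symint` maps a SymInt-bearing type to its concrete counterpart, and `guard_int` records a guard while producing the concrete value. A toy, self-contained model of the trait-plus-guard pattern (the `SymInt` stand-in here is ours, not the real `c10::SymInt`):

```cpp
// Toy stand-in: holds a concrete value; guard_int() is where the
// real class would record a shape guard before unwrapping.
struct SymInt {
  long v;
  long guard_int(const char* /*file*/, int /*line*/) const { return v; }
};

// remove_symint maps SymInt-bearing types to concrete ones.
template <typename T> struct remove_symint { using type = T; };
template <> struct remove_symint<SymInt> { using type = long; };

template <typename T>
typename remove_symint<T>::type unpackSymInt(T x) { return x; }

template <>
typename remove_symint<SymInt>::type unpackSymInt(SymInt x) {
  return x.guard_int(__FILE__, __LINE__);
}

int main() {
  SymInt s{42};
  return unpackSymInt(s) == 42 ? 0 : 1;
}
```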
@@ -631,8 +631,8 @@ call_functor_with_args_from_stack_(
     Stack* stack,
     std::index_sequence<ivalue_arg_indices...> /*unused*/,
     guts::typelist::typelist<ArgTypes...>* /*unused*/) {
-  (void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would
+  (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
                // be unused and we have to silence the compiler warning.
 
   // We're explicitly filtering out DispatchKeySet from the argument list.
   // Some kernels take a DispatchKeySet as their first argument in order to
@@ -18,7 +18,6 @@ struct TORCH_API EnumType : public NamedType {
       TypePtr value,
       std::vector<EnumNameValue> enum_names_values,
       std::weak_ptr<::torch::jit::CompilationUnit> cu) {
-    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
     switch (value->kind()) {
       case TypeKind::IntType:
       case TypeKind::FloatType:
@@ -35,7 +34,6 @@ struct TORCH_API EnumType : public NamedType {
           value->str(),
           "', only int, float and string are supported");
     }
-    C10_DIAGNOSTIC_POP()
   }
 
   std::string str() const override {
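Both `EnumType` hunks (like the `IListRef` and `DictType` ones in this change) drop a push/pop pair that silenced `-Wswitch-enum` around a deliberately non-exhaustive `switch`. A hedged sketch of what such a guard expands to, using plain GCC/Clang pragmas rather than the C10 macros:

```cpp
// Approximate expansion of the removed C10 diagnostic macros.
enum class Kind { Int, Float, String, Tensor };

bool is_scalar(Kind k) {
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wswitch-enum"
#endif
  switch (k) {  // deliberately not exhaustive; default covers the rest
    case Kind::Int:
    case Kind::Float:
      return true;
    default:
      return false;
  }
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
}

int main() { return is_scalar(Kind::Int) ? 0 : 1; }
```

`-Wswitch-enum` warns about unlisted enumerators even when a `default:` exists, which is exactly why the suppression was there.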
@@ -601,8 +601,8 @@ std::ostream& IValue::repr(
       double d = v.toDouble();
       int c = std::fpclassify(d);
       if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
-        int64_t i = static_cast<int64_t>(d);
-        if (static_cast<double>(i) == d) {
+        int64_t i = int64_t(d);
+        if (double(i) == d) {
           // -0.0 (signed zero) needs to be parsed as -0.
           if (i == 0 && std::signbit(d)) {
             return out << "-" << i << ".";
@@ -799,8 +799,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
       double d = v.toDouble();
       int c = std::fpclassify(d);
       if (c == FP_NORMAL || c == FP_ZERO) {
-        int64_t i = static_cast<int64_t>(d);
-        if (static_cast<double>(i) == d) {
+        int64_t i = int64_t(d);
+        if (double(i) == d) {
           return out << i << ".";
         }
       }
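Both hunks keep the round-trip printing rule: an integer-valued double is printed with a trailing `.` (and signed zero as `-0.`) so the text parses back as a floating-point literal rather than an int. A standalone sketch of that rule:

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// Print d so it re-parses as a double, not an int.
void print_double(std::ostream& out, double d) {
  int c = std::fpclassify(d);
  if ((c == FP_NORMAL || c == FP_ZERO) && std::abs(d) < 1e10) {
    auto i = static_cast<int64_t>(d);
    if (static_cast<double>(i) == d) {
      if (i == 0 && std::signbit(d)) {
        out << "-" << i << ".";  // -0.0 must survive as "-0."
        return;
      }
      out << i << ".";
      return;
    }
  }
  out << d;
}

int main() {
  print_double(std::cout, 3.0);   // "3."
  std::cout << '\n';
  print_double(std::cout, -0.0);  // "-0."
  std::cout << '\n';
}
```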
@@ -41,7 +41,7 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
 inline bool is_contiguous_strides(
     const IntArrayRef sizes,
     const IntArrayRef strides) {
-  size_t n_dim = sizes.size();
+  int n_dim = static_cast<int>(sizes.size());
   if (n_dim == 0) {
     return true;
   }
@@ -50,7 +50,7 @@ inline bool is_contiguous_strides(
     return false;
   }
 
-  for (int i = static_cast<int>(n_dim) - 2; i >= 0; i--) {
+  for (int i = n_dim - 2; i >= 0; i--) {
     if (strides[i] != strides[i + 1] * sizes[i + 1]) {
       return false;
     }
@@ -922,7 +922,6 @@ struct TORCH_API DictType : public SharedType {
     if (auto dyn = key->castRaw<DynamicType>()) {
       kind = dyn->dynamicKind();
     }
-    C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
     switch (kind) {
       case TypeKind::AnyType:
       case TypeKind::IntType:
@@ -939,7 +938,6 @@ struct TORCH_API DictType : public SharedType {
             key->str(),
             "', only int, float, complex, Tensor, device and string keys are supported");
     }
-    C10_DIAGNOSTIC_POP()
   }
 
   // aligned with the format in FunctionSchema
@@ -2373,7 +2371,7 @@ private:
 };
 
 template<>
-inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
+inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
   if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
       kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
     return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
@@ -2382,7 +2380,7 @@ inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
 }
 
 template<>
-inline detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
+inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
   if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
       kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
     return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());
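The first two hunks cover the row-major contiguity check: strides are contiguous iff `stride[i] == stride[i+1] * size[i+1]` scanning outward from the innermost dimension. A standalone version (the innermost-stride `!= 1` test falls between the two hunks, so it is assumed here):

```cpp
#include <vector>

bool is_contiguous_strides(const std::vector<long>& sizes,
                           const std::vector<long>& strides) {
  int n_dim = static_cast<int>(sizes.size());
  if (n_dim == 0) return true;
  if (strides[n_dim - 1] != 1) return false;  // innermost stride must be 1
  for (int i = n_dim - 2; i >= 0; i--) {
    if (strides[i] != strides[i + 1] * sizes[i + 1]) return false;
  }
  return true;
}

int main() {
  // A 2x3 row-major tensor has strides {3, 1}.
  return is_contiguous_strides({2, 3}, {3, 1}) ? 0 : 1;
}
```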
@@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
     auto vals = svreinterpret_u16_bf16(values);
     vals = sveor_u16_x(ptrue, vals, mask);
     return svreinterpret_bf16_u16(vals);
-  }
+  };
   Vectorized<BFloat16> round() const;
   Vectorized<BFloat16> tan() const;
   Vectorized<BFloat16> tanh() const;
@@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
     return convert_float_bfloat16(v1, v2); \
   }
 
-DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
-DEFINE_BF16_FUNC_VIA_FLOAT(angle)
-DEFINE_BF16_FUNC_VIA_FLOAT(acos)
-DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
-DEFINE_BF16_FUNC_VIA_FLOAT(asin)
-DEFINE_BF16_FUNC_VIA_FLOAT(atan)
-DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
-DEFINE_BF16_FUNC_VIA_FLOAT(erf)
-DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
-DEFINE_BF16_FUNC_VIA_FLOAT(exp)
-DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
-DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
-DEFINE_BF16_FUNC_VIA_FLOAT(i0)
-DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
-DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
-DEFINE_BF16_FUNC_VIA_FLOAT(log)
-DEFINE_BF16_FUNC_VIA_FLOAT(log2)
-DEFINE_BF16_FUNC_VIA_FLOAT(log10)
-DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
-DEFINE_BF16_FUNC_VIA_FLOAT(sin)
-DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
-DEFINE_BF16_FUNC_VIA_FLOAT(cos)
-DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
-DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
-DEFINE_BF16_FUNC_VIA_FLOAT(floor)
-DEFINE_BF16_FUNC_VIA_FLOAT(round)
-DEFINE_BF16_FUNC_VIA_FLOAT(tan)
-DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
-DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
-DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
-DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
-DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
-DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
-DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
+DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
+DEFINE_BF16_FUNC_VIA_FLOAT(angle);
+DEFINE_BF16_FUNC_VIA_FLOAT(acos);
+DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
+DEFINE_BF16_FUNC_VIA_FLOAT(asin);
+DEFINE_BF16_FUNC_VIA_FLOAT(atan);
+DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
+DEFINE_BF16_FUNC_VIA_FLOAT(erf);
+DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
+DEFINE_BF16_FUNC_VIA_FLOAT(exp);
+DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
+DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
+DEFINE_BF16_FUNC_VIA_FLOAT(i0);
+DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
+DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
+DEFINE_BF16_FUNC_VIA_FLOAT(log);
+DEFINE_BF16_FUNC_VIA_FLOAT(log2);
+DEFINE_BF16_FUNC_VIA_FLOAT(log10);
+DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
+DEFINE_BF16_FUNC_VIA_FLOAT(sin);
+DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
+DEFINE_BF16_FUNC_VIA_FLOAT(cos);
+DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
+DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
+DEFINE_BF16_FUNC_VIA_FLOAT(floor);
+DEFINE_BF16_FUNC_VIA_FLOAT(round);
+DEFINE_BF16_FUNC_VIA_FLOAT(tan);
+DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
+DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
+DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
+DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
+DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
+DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
+DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
 
 Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
     const Vectorized<BFloat16>& other) const {
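The `DEFINE_BF16_FUNC_VIA_FLOAT` family above implements bfloat16 math by widening to float, calling the float routine, and narrowing back. A scalar model of that widen-compute-narrow pattern (helper names are ours; truncation narrowing used for brevity, not the vector intrinsics the real code uses):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

static float bf16_to_float(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;  // bf16 = top 16 float bits
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

static uint16_t float_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncate the mantissa
}

// Widen, apply the float routine, narrow back.
#define DEFINE_BF16_FUNC_VIA_FLOAT(op) \
  static uint16_t bf16_##op(uint16_t x) { return float_to_bf16(std::op(bf16_to_float(x))); }

DEFINE_BF16_FUNC_VIA_FLOAT(sin)
DEFINE_BF16_FUNC_VIA_FLOAT(exp)

int main() {
  uint16_t one = float_to_bf16(1.0f);
  return bf16_to_float(bf16_exp(one)) > 2.0f ? 0 : 1;  // e ~= 2.718
}
```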
@@ -19,13 +19,6 @@ inline namespace CPU_CAPABILITY {
 #error "Big endian is not supported."
 #endif
 
-// GCC does not properly optimize bf16 operators
-#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
-#define BF16_ARITHMETIC_SUPPORTED() 1
-#else
-#define BF16_ARITHMETIC_SUPPORTED() 0
-#endif
-
 // Unlike the float16_t family of types, bfloat16_t is not available
 // when we're not targeting bfloat16 hardware support on some
 // platforms (but not Mac, so we have to be careful not to shadow the
@@ -359,72 +352,18 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
         other, &Vectorized<float>::name); \
   }
 
+  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
   Vectorized frac() const;
+  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
 
-#ifdef __ARM_FEATURE_BF16
-  // Flip sign bit
-  Vectorized<c10::BFloat16> neg() const {
-    return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
-  }
-  // Fast reciprocal is fine because we are truncating results
-  Vectorized<c10::BFloat16> reciprocal() const {
-    auto x = vcvtq_low_f32_bf16(values);
-    auto y = vcvtq_high_f32_bf16(values);
-    x = vrecpeq_f32(x);
-    y = vrecpeq_f32(y);
-    return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
-  }
-  // Clearing the sign bit
-  Vectorized<c10::BFloat16> abs() const {
-    return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
-  }
-#else
-  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
-  DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
   DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
-#endif
 
-  // These functions are optimized on clang-21+
-#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
-  Vectorized<c10::BFloat16> operator==(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values == other.values;
-  }
-
-  Vectorized<c10::BFloat16> operator!=(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values != other.values;
-  }
-
-  Vectorized<c10::BFloat16> operator<(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values < other.values;
-  }
-
-  Vectorized<c10::BFloat16> operator<=(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values <= other.values;
-  }
-
-  Vectorized<c10::BFloat16> operator>(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values > other.values;
-  }
-
-  Vectorized<c10::BFloat16> operator>=(
-      const Vectorized<c10::BFloat16>& other) const {
-    return values >= other.values;
-  }
-#else
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
   DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
-#endif
 
 #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
 #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@@ -473,52 +412,28 @@ template <>
 Vectorized<c10::BFloat16> inline operator+(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  return x + y;
-#else
   return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
-#endif
 }
 
 template <>
 Vectorized<c10::BFloat16> inline operator-(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  return x - y;
-#else
   return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
-#endif
 }
 
 template <>
 Vectorized<c10::BFloat16> inline operator*(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  return x * y;
-#else
   return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
-#endif
 }
 
 template <>
 Vectorized<c10::BFloat16> inline operator/(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  return x / y;
-#else
   return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
-#endif
 }
 
 // frac. Implement this here so we can use subtraction
@@ -629,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  bfloat16x8_t z = c;
-  return x * y + z;
-#else
   // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
   // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
   // elements, not the bottom and top half, so they don't seem
   // particularly useful here. Ideally we would include dot product in
   // the Vectorized interface...
   return a * b + c;
-#endif
 }
 
 template <>
@@ -649,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  bfloat16x8_t z = c;
-  return (-x) * y + z;
-#else
   // See NOTE [BF16 FMA] above.
   return -a * b + c;
-#endif
 }
 
 template <>
@@ -665,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  bfloat16x8_t z = c;
-  return x * y - z;
-#else
   // See NOTE [BF16 FMA] above.
   return a * b - c;
-#endif
 }
 
 template <>
@@ -681,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
     const Vectorized<c10::BFloat16>& a,
     const Vectorized<c10::BFloat16>& b,
     const Vectorized<c10::BFloat16>& c) {
-#if BF16_ARITHMETIC_SUPPORTED()
-  bfloat16x8_t x = a;
-  bfloat16x8_t y = b;
-  bfloat16x8_t z = c;
-  return (-x) * y - z;
-#else
   // See NOTE [BF16 FMA] above.
   return -a * b - c;
-#endif
 }
 
 #endif // !defined(C10_MOBILE) && defined(__aarch64__)
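NOTE [BF16 FMA] above explains the design: there is no FMA that accumulates into bf16, so `fmadd` falls back to a separate multiply and add, each widening through float. A minimal sketch of the gate-plus-fallback structure (the `Vec` type is a stand-in, not the real `Vectorized<c10::BFloat16>` interface):

```cpp
// Prefer native bf16 arithmetic only where the compiler is known to
// optimize it (clang >= 19 in the removed gate); otherwise fall back.
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
#define BF16_ARITHMETIC_SUPPORTED() 1
#else
#define BF16_ARITHMETIC_SUPPORTED() 0
#endif

struct Vec { float v; };  // stand-in for a vector of bf16 lanes

inline Vec operator*(Vec a, Vec b) { return {a.v * b.v}; }
inline Vec operator+(Vec a, Vec b) { return {a.v + b.v}; }

inline Vec fmadd(Vec a, Vec b, Vec c) {
#if BF16_ARITHMETIC_SUPPORTED()
  return a * b + c;  // native path: x * y + z on bfloat16x8_t values
#else
  // No bf16-accumulating FMA: separate multiply and add, each
  // rounding through float (two roundings instead of one).
  return a * b + c;
#endif
}

int main() { return fmadd({2}, {3}, {4}).v == 10 ? 0 : 1; }
```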
@@ -6,9 +6,9 @@ namespace at::vec {
 inline namespace CPU_CAPABILITY {
 #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
 
-// Enable auto-vectorization for clang-17+
+// Enable auto-vectorization for GCC-13+ and clang-17+
 // GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
-#if defined(__clang__) && (__clang_major__ >= 17)
+#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
 
 template <typename from_type, typename to_type>
 inline void convertImpl(
@@ -191,37 +191,22 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
 }
 
 #endif
-
-template <typename to_type>
-inline void convertFromBf16Impl(
-    const c10::BFloat16* __restrict src,
-    to_type* __restrict dst,
-    int64_t n) {
-  const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src);
-  uint64_t len = static_cast<uint64_t>(n);
-  for (uint64_t i = 0; i < len; i++) {
-    uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16;
-    float tmpF;
-    __builtin_memcpy(&tmpF, &tmp, sizeof(float));
-    dst[i] = static_cast<to_type>(tmpF);
-  }
-}
-#define CONVERT_FROM_BF16_TEMPLATE(to_type)                                 \
-  template <>                                                               \
-  inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) {  \
-    return convertFromBf16Impl<to_type>(src, dst, n);                       \
-  }
-
-CONVERT_FROM_BF16_TEMPLATE(uint8_t)
-CONVERT_FROM_BF16_TEMPLATE(int8_t)
-CONVERT_FROM_BF16_TEMPLATE(int16_t)
-CONVERT_FROM_BF16_TEMPLATE(int32_t)
-CONVERT_FROM_BF16_TEMPLATE(int64_t)
-CONVERT_FROM_BF16_TEMPLATE(float)
-CONVERT_FROM_BF16_TEMPLATE(double)
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-CONVERT_FROM_BF16_TEMPLATE(float16_t)
-#endif
+#ifdef __ARM_FEATURE_BF16
+CONVERT_TEMPLATE(bfloat16_t, uint8_t)
+CONVERT_TEMPLATE(bfloat16_t, int8_t)
+CONVERT_TEMPLATE(bfloat16_t, int16_t)
+CONVERT_TEMPLATE(bfloat16_t, int32_t)
+CONVERT_TEMPLATE(bfloat16_t, int64_t)
+CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
+CONVERT_TEMPLATE(bfloat16_t, float)
+CONVERT_TEMPLATE(bfloat16_t, double)
+CONVERT_TEMPLATE(uint8_t, bfloat16_t)
+CONVERT_TEMPLATE(int8_t, bfloat16_t)
+CONVERT_TEMPLATE(int16_t, bfloat16_t)
+CONVERT_TEMPLATE(int32_t, bfloat16_t)
+CONVERT_TEMPLATE(int64_t, bfloat16_t)
+CONVERT_TEMPLATE(float, bfloat16_t)
+CONVERT_TEMPLATE(double, bfloat16_t)
 
 inline void convertBoolToBfloat16Impl(
     const bool* __restrict src,
@@ -262,6 +247,8 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
 
 #endif
 
+#endif
+
 template <typename src_t>
 struct VecConvert<
     float,
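The removed `convertFromBf16Impl` shows the scalar widening trick: a bfloat16 is exactly the top 16 bits of an IEEE-754 float, so widening is a 16-bit shift plus a bit copy. A standalone version of just that conversion:

```cpp
#include <cstdint>
#include <cstring>

// bf16 -> float: place the 16 bf16 bits in the top half of a
// 32-bit word; the low mantissa bits are zero.
float bf16_bits_to_float(uint16_t src) {
  uint32_t tmp = static_cast<uint32_t>(src) << 16;
  float out;
  std::memcpy(&out, &tmp, sizeof(float));
  return out;
}

int main() {
  // 0x3F80 is 1.0f in bfloat16 (sign 0, exponent 127, mantissa 0).
  return bf16_bits_to_float(0x3F80) == 1.0f ? 0 : 1;
}
```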
@@ -309,7 +309,7 @@ class Vectorized<float> {
   DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
   // Implementation copied from Arm Optimized Routine
   // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
-  inline Vectorized<float> vexpq_f32_u20() const {
+  Vectorized<float> exp_u20() const {
     // bail out to sleef if it's a special case:
     // i.e. there's an input s.t. |input| > 87.3....
     const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
@@ -348,9 +348,6 @@ class Vectorized<float> {
 
     return vfmaq_f32(scale, poly, scale);
   }
-  Vectorized<float> exp_u20() const {
-    return vexpq_f32_u20();
-  }
   Vectorized<float> fexp_u20() const {
     return exp_u20();
   }
@@ -637,7 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
   // - exp(- x * x)
   auto pow_2 = (*this) * (*this);
   auto neg_pow_2 = pow_2 ^ neg_zero_vec;
-  auto tmp4 = neg_pow_2.vexpq_f32_u20();
+  auto tmp4 = neg_pow_2.exp();
   auto tmp5 = tmp4 ^ neg_zero_vec;
   // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
   auto tmp6 = t * tmp5;
@@ -514,7 +514,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 
   using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
   using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
-  using value_type = c10::qint8::underlying;
+  using value_type = typename c10::qint8::underlying;
 
  public:
   using Vectorizedqi::Vectorizedqi;
@@ -727,7 +727,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 
   using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
   using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
-  using value_type = c10::quint8::underlying;
+  using value_type = typename c10::quint8::underlying;
 
  public:
   using Vectorizedqi::Vectorizedqi;
@@ -567,7 +567,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
 
   using float_vec_return_type = std::array<Vectorized<float>, 4>;
   using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
-  using value_type = c10::qint8::underlying;
+  using value_type = typename c10::qint8::underlying;
 
  public:
   using Vectorizedqi::Vectorizedqi;
@@ -804,7 +804,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
 
   using float_vec_return_type = std::array<Vectorized<float>, 4>;
   using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
-  using value_type = c10::quint8::underlying;
+  using value_type = typename c10::quint8::underlying;
 
  public:
   using Vectorizedqi::Vectorizedqi;
@@ -672,7 +672,7 @@ struct Vectorized {
     return map(std::sqrt);
   }
   Vectorized<T> reciprocal() const {
-    return map([](T x) { return (T)1 / x; });
+    return map([](T x) { return (T)(1) / x; });
   }
   Vectorized<T> rsqrt() const {
     return map([](T x) { return (T)1 / std::sqrt(x); });
@@ -46,7 +46,7 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
   parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
     map(
         [](const Vectorized<scalar_t>& x) {
-          return Vectorized<scalar_t>((scalar_t)1) / x.sqrt();
+          return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
         },
         out + begin,
         in + begin,
@@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
 #ifndef USE_ROCM
   at::Half halpha;
   at::Half hbeta;
-  uint32_t mask = -1;
 #endif
   void * alpha_ptr = &alpha;
   void * beta_ptr = &beta;
@@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
     if (fp16_reduction !=
         at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
-      mask =
+      uint32_t mask =
           fp16_reduction ==
                   at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
               ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
     if (bf16_reduction !=
         at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
-      mask =
+      uint32_t mask =
          bf16_reduction ==
                   at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
               ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
   cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  // on Blackwell+, we fake a n > 1 matmul when querying heuristics
-  // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
-#ifndef USE_ROCM
-  const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
-#else
-  const bool lie_to_cublaslt = false;
-#endif
-  if (lie_to_cublaslt) {
-    CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
-    CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
-
-    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-        ltHandle,
-        computeDesc.descriptor(),
-        Adesc.descriptor(),
-        FakeBdesc.descriptor(),
-        FakeCdesc.descriptor(),
-        FakeCdesc.descriptor(),
-        preference.descriptor(),
-        1,
-        &heuristicResult,
-        &returnedResult));
-  } else {
-    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-        ltHandle,
-        computeDesc.descriptor(),
-        Adesc.descriptor(),
-        Bdesc.descriptor(),
-        Cdesc.descriptor(),
-        Cdesc.descriptor(),
-        preference.descriptor(),
-        1,
-        &heuristicResult,
-        &returnedResult));
-  }
+  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
+      ltHandle,
+      computeDesc.descriptor(),
+      Adesc.descriptor(),
+      Bdesc.descriptor(),
+      Cdesc.descriptor(),
+      Cdesc.descriptor(),
+      preference.descriptor(),
+      1,
+      &heuristicResult,
+      &returnedResult));
   if (returnedResult == 0) {
     cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
   }
@@ -1597,7 +1572,7 @@ bool gemm_and_bias(
   }
 
   using opmath_t = at::opmath_type<Dtype>;
-  opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
+  opmath_t beta_val = 0; // bias is added in epilogue
 
   cudaDataType_t abType = CUDA_R_32F;
   cudaDataType_t cType = CUDA_R_32F;
@@ -1686,22 +1661,15 @@ bool gemm_and_bias(
     _syncCurrentWithCarveoutStream(stream, true);
   }
 #endif
-  const auto epilogue = [&]() -> cublasLtEpilogue_t {
-    // The cuBLAS documentation indicates that
-    // *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
-    // but we keep it verbose here for clarity.
-    switch (activation) {
-      case GEMMAndBiasActivationEpilogue::RELU:
-        return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
-      case GEMMAndBiasActivationEpilogue::GELU:
-        return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
-      default:
-        return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
-    }
-  }();
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
+    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
+    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
+  }
 
-  if (bias) {
+  if (bias != nullptr) {
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
   }
 
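The removed comment explains the motivation: on Blackwell (SM 10.0+), when the reduction scheme is forced to NONE and n == 1, the heuristic query is given a fake n == 2 B/C layout so cuBLASLt does not select a GEMV kernel whose reduction order would break batch-invariant numerics. A standalone mock of just the decision logic (the struct and constant here are stand-ins, not the real cuBLASLt API):

```cpp
#include <cstdint>
#include <cstdio>

constexpr uint32_t REDUCTION_SCHEME_NONE = 0;  // stand-in constant

struct DeviceProps { int major; };  // stand-in for cudaDeviceProp

// Fake an n == 2 matmul when querying heuristics so the library does
// not dispatch the n == 1 case to a GEMV kernel; the real matmul is
// still launched with the true descriptors afterwards.
bool should_fake_n2(uint32_t mask, int64_t n, const DeviceProps& props) {
  return mask == REDUCTION_SCHEME_NONE && n == 1 && props.major >= 10;
}

int main() {
  DeviceProps blackwell{10};
  std::printf("fake layout: %s\n",
              should_fake_n2(REDUCTION_SCHEME_NONE, 1, blackwell) ? "yes" : "no");
}
```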
@@ -194,8 +194,8 @@ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
 void CUDAGeneratorState::capture_prologue() {
   capturing_ = true;
   offset_intragraph_ = 0;
-  seed_extragraph_.fill_(static_cast<int64_t>(seed_));
-  offset_extragraph_.fill_(0);
+  seed_extragraph_.fill_(int64_t(seed_));
+  offset_extragraph_.fill_(int64_t(0));
 }
 
 /**
@@ -216,8 +216,8 @@ void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
   at::cuda::assertNotCapturing(
       "Cannot prepare for replay during capturing stage.");
   if (wholegraph_increment) {
-    seed_extragraph_.fill_(static_cast<int64_t>(seed_));
-    offset_extragraph_.fill_(static_cast<int64_t>(philox_offset_per_thread_));
+    seed_extragraph_.fill_(int64_t(seed_));
+    offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
     // Applies the total increment achieved during previous captures to update the
     // offset.
     increase(wholegraph_increment);
@@ -329,7 +329,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
   constexpr size_t offset_size = sizeof(int64_t);
   constexpr size_t total_size = seed_size + offset_size;
 
-  auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(total_size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
+  auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
   auto rng_state = state_tensor.data_ptr<uint8_t>();
   auto current_seed = this->current_seed();
   auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>
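`get_state()` packs the seed and the Philox offset into an opaque byte tensor of `seed_size + offset_size` bytes. A plain-memory sketch of that packing (the `uint64_t` seed width is our assumption; the hunk only shows `offset_size = sizeof(int64_t)`):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Opaque RNG state: seed bytes followed by philox offset bytes.
std::vector<uint8_t> pack_rng_state(uint64_t seed, int64_t offset) {
  constexpr size_t seed_size = sizeof(uint64_t);
  constexpr size_t offset_size = sizeof(int64_t);
  std::vector<uint8_t> state(seed_size + offset_size);
  std::memcpy(state.data(), &seed, seed_size);
  std::memcpy(state.data() + seed_size, &offset, offset_size);
  return state;
}

int main() {
  auto s = pack_rng_state(67280421310721ULL, 40);
  uint64_t seed;
  std::memcpy(&seed, s.data(), sizeof(seed));
  return seed == 67280421310721ULL ? 0 : 1;
}
```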
@@ -1,90 +1,78 @@
 #include <ATen/cuda/CUDAGreenContext.h>
 
-#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
-#include <c10/cuda/driver_api.h>
-#include <stdexcept>
-#include <vector>
-#define HAS_CUDA_GREEN_CONTEXT() 1
-#else
-#define HAS_CUDA_GREEN_CONTEXT() 0
-// Suppress unsued private field warnings as this class is not supposed to be called
-C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
-#endif
-
 namespace at::cuda {
-
 GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   int driver_version;
   C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
   TORCH_CHECK(
       driver_version >= 12080, "cuda driver too old to use green context!");
   CUcontext pctx = nullptr;
   C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
   if (C10_UNLIKELY(!pctx)) {
     TORCH_WARN(
         "Attempted to create a green context but"
         " there was no primary context! Creating a primary context...");
 
     cudaFree(0);
   }
 
   CUdevice device;
   device_id_ = device_id;
   C10_CUDA_DRIVER_CHECK(
       c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
 
   // Get device resources
   CUdevResource device_resource;
   C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
       device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
 
   // Split resources
   std::vector<CUdevResource> result(1);
   auto result_data = result.data();
   unsigned int nb_groups = 1;
   CUdevResource remaining;
 
   C10_CUDA_DRIVER_CHECK(
       c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
           result_data,
           &nb_groups,
           &device_resource,
           &remaining,
           0, // default flags
           num_sms));
 
   TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
 
   // Generate resource descriptor
   CUdevResourceDesc desc;
   C10_CUDA_DRIVER_CHECK(
       c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
           &desc, result_data, 1));
 
   // Create green context
   // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
   // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
   C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
       &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
 
   // Convert to regular context
   C10_CUDA_DRIVER_CHECK(
       c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
   TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
 #else
   TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 #endif
 }
 
 std::unique_ptr<GreenContext> GreenContext::create(
     uint32_t num_sms,
     std::optional<uint32_t> device_id) {
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   if (!device_id.has_value()) {
     device_id = at::cuda::current_device();
   }
-  return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms));
+  return std::make_unique<GreenContext>(device_id.value(), num_sms);
 #else
   TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
 #endif
@@ -92,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 
 // Implement move operations
 GreenContext::GreenContext(GreenContext&& other) noexcept{
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   device_id_ = std::exchange(other.device_id_, -1);
   green_ctx_ = std::exchange(other.green_ctx_, nullptr);
   context_ = std::exchange(other.context_, nullptr);
@@ -103,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 }
 
 GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   if (this != &other) {
     // Clean up current resources
     if (green_ctx_) {
@@ -132,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 }
 
 GreenContext::~GreenContext() noexcept{
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   C10_CUDA_DRIVER_CHECK(
       c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
 #else
@@ -140,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 #endif
 }
 
+// Get the underlying CUDA context
+CUcontext GreenContext::getContext() const {
+#if CUDA_HAS_GREEN_CONTEXT
+  return context_;
+#else
+  TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
+#endif
+}
+
+// Get the underlying green context
+#if CUDA_HAS_GREEN_CONTEXT
+CUgreenCtx GreenContext::getGreenContext() const {
+  return green_ctx_;
+}
+#endif
+
 // Make this context current
 void GreenContext::setContext() {
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   auto current_stream = c10::cuda::getCurrentCUDAStream();
   parent_stream_ = current_stream.stream();
 
@@ -171,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
 }
 
 void GreenContext::popContext() {
-#if HAS_CUDA_GREEN_CONTEXT()
+#if CUDA_HAS_GREEN_CONTEXT
   // see above note about stream being hardcoded to the default stream
   at::cuda::CUDAEvent ev;
   ev.record(c10::cuda::getCurrentCUDAStream());
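The constructor above walks a fixed driver-API sequence: get the device, query its SM resource, split off one group of `num_sms` SMs, generate a descriptor, create the green context, and convert it to a regular context. A self-contained mock of just the resource-split step (types and numbers here are illustrative stand-ins, not the CUDA driver API):

```cpp
#include <cstdio>
#include <stdexcept>

// Mock of cuDevSmResourceSplitByCount_: request one group of num_sms
// SMs out of the device total; exactly one group must come back.
struct SplitResult { unsigned groups; unsigned sms_in_group; unsigned remaining; };

SplitResult split_by_count(unsigned device_sms, unsigned num_sms) {
  if (num_sms == 0 || num_sms > device_sms)
    throw std::runtime_error("invalid SM count");
  return {1u, num_sms, device_sms - num_sms};
}

int main() {
  auto r = split_by_count(132, 16);  // e.g. carve 16 SMs off an H100-class GPU
  std::printf("groups=%u sms=%u remaining=%u\n",
              r.groups, r.sms_in_group, r.remaining);
}
```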
@ -1,38 +1,53 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <ATen/cuda/CUDAEvent.h>
|
#include <ATen/cuda/CUDAEvent.h>
|
||||||
#include <cuda.h>
|
|
||||||
|
|
||||||
// Forward declare green context as opaque ptr
|
#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
|
||||||
typedef struct CUgreenCtx_st* CUgreenCtx;
|
#include <c10/cuda/driver_api.h>
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
#define CUDA_HAS_GREEN_CONTEXT 1
|
||||||
|
#else
|
||||||
|
#define CUDA_HAS_GREEN_CONTEXT 0
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace at::cuda {
|
namespace at::cuda {
|
||||||
|
|
||||||
class TORCH_CUDA_CPP_API GreenContext {
|
class TORCH_CUDA_CPP_API GreenContext {
|
||||||
public:
|
public:
|
||||||
// Green context creation
|
GreenContext(uint32_t device_id, uint32_t num_sms);
|
||||||
static std::unique_ptr<GreenContext> create(
|
|
||||||
uint32_t num_sms,
|
static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
|
||||||
std::optional<uint32_t> device_id);
|
|
||||||
~GreenContext() noexcept;
|
|
||||||
|
|
||||||
// Delete copy constructor and assignment
|
// Delete copy constructor and assignment
|
||||||
GreenContext(const GreenContext&) = delete;
|
GreenContext(const GreenContext&) = delete;
|
||||||
GreenContext& operator=(const GreenContext&) = delete;
|
GreenContext& operator=(const GreenContext&) = delete;
|
||||||
|
|
||||||
|
// Implement move operations
|
||||||
|
GreenContext(GreenContext&& other) noexcept;
|
||||||
|
GreenContext& operator=(GreenContext&& other) noexcept;
|
||||||
|
~GreenContext() noexcept;
|
||||||
|
|
||||||
|
// Get the underlying CUDA context
|
||||||
|
CUcontext getContext() const;
|
||||||
|
|
||||||
|
// Get the underlying green context
|
||||||
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
|
CUgreenCtx getGreenContext() const;
|
||||||
|
#endif
|
||||||
|
|
||||||
// Make this context current
|
// Make this context current
|
||||||
void setContext();
|
void setContext();
|
||||||
|
|
||||||
void popContext();
|
void popContext();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
GreenContext(uint32_t device_id, uint32_t num_sms);
|
#if CUDA_HAS_GREEN_CONTEXT
|
||||||
// Implement move operations
|
|
||||||
GreenContext(GreenContext&& other) noexcept;
|
|
||||||
GreenContext& operator=(GreenContext&& other) noexcept;
|
|
||||||
|
|
||||||
int32_t device_id_ = -1;
|
int32_t device_id_ = -1;
|
||||||
CUgreenCtx green_ctx_ = nullptr;
|
CUgreenCtx green_ctx_ = nullptr;
|
||||||
CUcontext context_ = nullptr;
|
CUcontext context_ = nullptr;
|
||||||
cudaStream_t parent_stream_ = nullptr;
|
cudaStream_t parent_stream_ = nullptr;
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
} // namespace at::cuda
|
} // namespace at::cuda
|
||||||
|
|||||||
@ -7,6 +7,17 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#if defined(USE_ROCM)
|
||||||
|
// hipSparse const API added in v2.4.0
|
||||||
|
#if HIPSPARSE_VERSION >= 200400
|
||||||
|
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||||
|
#else
|
||||||
|
#define AT_USE_HIPSPARSE_GENERIC_API() 1
|
||||||
|
#endif
|
||||||
|
#else // USE_ROCM
|
||||||
|
#define AT_USE_HIPSPARSE_GENERIC_API() 0
|
||||||
|
#endif // USE_ROCM
|
||||||
|
|
||||||
// cuSparse Generic API spsv function was added in CUDA 11.3.0
|
// cuSparse Generic API spsv function was added in CUDA 11.3.0
|
||||||
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
|
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
|
||||||
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
|
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1
|
||||||
|
|||||||
@ -155,8 +155,8 @@ size_t parseChosenWorkspaceSize() {
|
|||||||
while (next != end) {
|
while (next != end) {
|
||||||
std::smatch match = *next;
|
std::smatch match = *next;
|
||||||
TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
|
TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
|
||||||
size_t curr_size = std::stoull(match.str(1));
|
size_t curr_size = (size_t) std::stoi(match.str(1));
|
||||||
size_t count = std::stoull(match.str(2));
|
size_t count = (size_t) std::stoi(match.str(2));
|
||||||
total_size += curr_size * 1024 * count;
|
total_size += curr_size * 1024 * count;
|
||||||
next++;
|
next++;
|
||||||
}
|
}
|
||||||
|
|||||||
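The loop above sums workspace entries of the form `:SIZE:COUNT` (size in KiB). A standalone version of the accumulation — the regex pattern is our assumption, since the hunk starts after the iterator setup:

```cpp
#include <cstdio>
#include <regex>
#include <string>

// Sum SIZE KiB * COUNT over every ":SIZE:COUNT" entry.
size_t parse_workspace_config(const std::string& val) {
  size_t total_size = 0;
  const std::regex config_regex(":([0-9]+):([0-9]+)");
  auto next = std::sregex_iterator(val.begin(), val.end(), config_regex);
  auto end = std::sregex_iterator();
  for (; next != end; ++next) {
    std::smatch match = *next;
    size_t curr_size = std::stoull(match.str(1));
    size_t count = std::stoull(match.str(2));
    total_size += curr_size * 1024 * count;
  }
  return total_size;
}

int main() {
  // 4096 KiB x 2 + 16 KiB x 8 = 8519680 bytes
  std::printf("%zu\n", parse_workspace_config(":4096:2:16:8"));
}
```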
@ -55,14 +55,6 @@ struct numeric_limits<int8_t> {
|
|||||||
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct numeric_limits<uint16_t> {
|
|
||||||
static inline __host__ __device__ uint16_t lowest() { return 0; }
|
|
||||||
static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
|
|
||||||
static inline __host__ __device__ uint16_t lower_bound() { return 0; }
|
|
||||||
static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct numeric_limits<int16_t> {
|
struct numeric_limits<int16_t> {
|
||||||
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
|
||||||
@ -71,14 +63,6 @@ struct numeric_limits<int16_t> {
|
|||||||
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct numeric_limits<uint32_t> {
|
|
||||||
static inline __host__ __device__ uint32_t lowest() { return 0; }
|
|
||||||
static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
|
|
||||||
static inline __host__ __device__ uint32_t lower_bound() { return 0; }
|
|
||||||
static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct numeric_limits<int32_t> {
|
struct numeric_limits<int32_t> {
|
||||||
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
|
||||||
@ -87,21 +71,6 @@ struct numeric_limits<int32_t> {
|
|||||||
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct numeric_limits<uint64_t> {
|
|
||||||
#ifdef _MSC_VER
|
|
||||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
|
||||||
static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
|
|
||||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
|
||||||
static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
|
|
||||||
#else
|
|
||||||
static inline __host__ __device__ uint64_t lowest() { return 0; }
|
|
||||||
static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
|
|
||||||
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
|
|
||||||
static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
|
|
||||||
#endif
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct numeric_limits<int64_t> {
|
struct numeric_limits<int64_t> {
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
|
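Note: these hunks add/remove unsigned specializations of a device-side numeric_limits trait (std::numeric_limits is not usable in __device__ code in older CUDA toolchains). A minimal host-compilable sketch of the same pattern; `numeric_limits` here is a local template, not std::, and the empty __host__/__device__ macro definitions are only a stand-in for a non-CUDA build:

    // Sketch of the device-side numeric_limits pattern shown above.
    #include <cstdint>

    #if !defined(__CUDACC__)
    #define __host__
    #define __device__
    #endif

    template <typename T> struct numeric_limits;

    template <>
    struct numeric_limits<uint16_t> {
      static inline __host__ __device__ uint16_t lowest() { return 0; }
      static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
    };

    int main() { return numeric_limits<uint16_t>::max() == 65535 ? 0 : 1; }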
@ -24,13 +24,7 @@ namespace detail {
 // radix_sort_pairs doesn't interact with value_t other than to copy
 // the data, so we can save template instantiations by reinterpreting
 // it as an opaque type.
-// We use native integer types for 1/2/4/8-byte values to reduce
-// register usage in CUDA kernels. For sizes > 8 fall back to char array.
 template <int N> struct alignas(N) OpaqueType { char data[N]; };
-template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
-template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
-template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
-template <> struct alignas(8) OpaqueType<8> { uint64_t data; };

 template<typename key_t, int value_size>
 void radix_sort_pairs_impl(
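Note: since radix_sort_pairs only copies its values, sorting them through a type keyed solely on size and alignment lets one template instantiation serve every value type of that size; the removed specializations merely swapped the char array for a native integer of equal width. A standalone sketch of the trick, with `Payload` as a hypothetical value type:

    // One OpaqueType<8> instantiation can stand in for any 8-byte payload.
    #include <cstdint>
    #include <cstring>
    #include <iostream>

    template <int N> struct alignas(N) OpaqueType { char data[N]; };

    struct Payload { int64_t id; };  // hypothetical 8-byte value type

    int main() {
      static_assert(sizeof(Payload) == sizeof(OpaqueType<8>), "size must match");
      Payload p{42};
      OpaqueType<8> blob;
      std::memcpy(&blob, &p, sizeof(blob));    // reinterpret by copy
      Payload back;
      std::memcpy(&back, &blob, sizeof(back));
      std::cout << back.id << '\n';            // prints 42
    }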
@ -2,6 +2,8 @@
 #include <ATen/Tensor.h>
 #include <ATen/cuda/Exceptions.h>

+#include <mutex>
+
 namespace at {
 namespace cuda {
 namespace detail {

@ -10,36 +12,39 @@ __device__ __constant__ float cublas_one_device;
 __device__ __constant__ float cublas_zero_device;

 float *get_cublas_device_one() {
-  static float *ptr = nullptr;
-  static auto init_flag = [&]() {
+  static c10::once_flag init_flag;
+  c10::call_once(init_flag, []() {
     const float one = 1.f;
     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
-    AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
-    return true;
-  }();
+  });

+  float *ptr;
+  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
   return ptr;
 }

 float *get_cublas_device_zero() {
-  static float *ptr = nullptr;
-  static auto init_flag = [&]() {
+  static c10::once_flag init_flag;
+  c10::call_once(init_flag, []() {
     const float zero = 0.f;
     AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
-    AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
-    return true;
-  }();
+  });

+  float *ptr;
+  AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
   return ptr;
 }

 float *get_user_alpha_ptr() {
   static float *alpha_ptr;

-  static bool init_flag [[maybe_unused]] = []() {
+  static c10::once_flag init_flag;
+
+  c10::call_once(init_flag, []() {
     AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
-    return true;
-  }();
+  });

   return alpha_ptr;
 }
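Note: both sides of this hunk perform one-time initialization; the left uses the static-lambda-result idiom, the right an explicit once-flag. A standalone sketch of the once-flag shape using std::call_once (c10::call_once follows the same pattern); the CUDA calls are replaced here by a plain malloc:

    // Thread-safe one-time initialization with a once flag.
    #include <cstdlib>
    #include <mutex>

    float* get_buffer() {
      static float* ptr = nullptr;
      static std::once_flag init_flag;
      std::call_once(init_flag, []() {
        // One-time setup; later callers skip straight to the return.
        ptr = static_cast<float*>(std::malloc(sizeof(float)));
      });
      return ptr;
    }

    int main() { return get_buffer() != nullptr ? 0 : 1; }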
@ -21,7 +21,6 @@

 #if AT_CUDNN_ENABLED()
 #include <ATen/cudnn/cudnn-wrapper.h>
-#include <cudnn_frontend.h>
 #endif

 #if AT_MAGMA_ENABLED()

@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }

-long CUDAHooks::versionRuntimeCuDNN() const {
-#if AT_CUDNN_ENABLED()
-#ifndef USE_STATIC_CUDNN
-  return cudnnGetVersion();
-#else
-  return CUDNN_VERSION;
-#endif
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
-long CUDAHooks::versionCuDNNFrontend() const {
-#if AT_CUDNN_ENABLED()
-  return CUDNN_FRONTEND_VERSION;
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
 long CUDAHooks::versionMIOpen() const {
 #if AT_ROCM_ENABLED()
   return MIOPEN_VERSION_MAJOR * 10000 +
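Note: the removed versionRuntimeCuDNN hook captured a real distinction: CUDNN_VERSION is the header the binary was compiled against, while cudnnGetVersion() reports the library actually loaded at runtime, and the two can differ when cuDNN is dynamically linked. A sketch of that distinction (requires the cuDNN headers and library to build):

    // Compile-time vs runtime cuDNN version.
    #include <cudnn.h>
    #include <cstdio>

    int main() {
      std::printf("compiled against: %d\n", CUDNN_VERSION);       // header macro
      std::printf("loaded at runtime: %zu\n", cudnnGetVersion()); // library query
    }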
@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
-  long versionRuntimeCuDNN() const override;
-  long versionCuDNNFrontend() const override;
   long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
@ -3,7 +3,6 @@
 #include <ATen/ATen.h>
 #include <c10/util/irange.h>

-#include <array>
 #include <iostream>
 #include <sstream>

@ -137,9 +136,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format
       "Weight strides: ", t.strides(), "\n",
       "cuDNN suggested memory_format: ", memory_format);

-  std::array<int, CUDNN_DIM_MAX> size;
+  int size[CUDNN_DIM_MAX];
   for (const auto i : c10::irange(dim)) {
-    size[i] = static_cast<int>(t.size(i));
+    size[i] = (int) t.size(i);
   }
   for (const auto i : c10::irange(dim, pad)) {
     size[i] = 1;

@ -157,7 +156,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format
     default:
       TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
   }
-  set(getDataType(t), static_cast<int>(dim), size.data(), filter_format);
+  set(getDataType(t), static_cast<int>(dim), size, filter_format);
 }

 std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {
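Note: the std::array and C-array variants above have identical layout; the visible difference is that std::array must expose its buffer via .data() when handed to a C-style API, while a C array decays to a pointer implicitly. A standalone illustration with a hypothetical c_api function:

    // std::array vs C array when passing a buffer to a C-style API.
    #include <array>
    #include <cstdio>

    void c_api(const int* v, int n) {
      for (int i = 0; i < n; ++i) std::printf("%d ", v[i]);
      std::printf("\n");
    }

    int main() {
      std::array<int, 3> a{1, 2, 3};
      int b[3] = {1, 2, 3};
      c_api(a.data(), 3);  // explicit pointer access
      c_api(b, 3);         // array decays to pointer implicitly
    }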
Some files were not shown because too many files have changed in this diff.