Compare commits

..

3 Commits

Author SHA1 Message Date
bdc1181606 more path stuff 2025-11-04 09:57:11 -08:00
f574505205 some path stuff 2025-11-04 09:26:28 -08:00
4dd22ac43c upload jsons while running 2025-11-04 08:46:17 -08:00
1122 changed files with 12453 additions and 34401 deletions

View File

@ -13,4 +13,3 @@ exclude:
- "**/benchmarks/**" - "**/benchmarks/**"
- "**/test_*.py" - "**/test_*.py"
- "**/*_test.py" - "**/*_test.py"
- "tools/**"

View File

@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8 ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8 ENV LANGUAGE en_US.UTF-8
ARG DEVTOOLSET_VERSION=13 ARG DEVTOOLSET_VERSION=11
RUN yum -y update RUN yum -y update
RUN yum -y install epel-release RUN yum -y install epel-release
# install glibc-langpack-en make sure en_US.UTF-8 locale is available # install glibc-langpack-en make sure en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
# Just add everything as a safe.directory for git since these will be used in multiple places with git # Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*' RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA # Install CUDA
FROM base as cuda FROM base as cuda
ARG CUDA_VERSION=12.6 ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-* RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh COPY ./common/install_nccl.sh install_nccl.sh
@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds # Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION} ENV CUDA_VERSION=${CUDA_VERSION}
# Make things in our path by default # Make things in our path by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
FROM cuda as cuda12.6 FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6 RUN bash ./install_cuda.sh 12.6
@ -70,22 +68,8 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0 RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0 ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm_base FROM ${ROCM_IMAGE} as rocm
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
yum -y install epel-release && \
yum -y install glibc-langpack-en && \
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM rocm_base as rocm
ARG PYTORCH_ROCM_ARCH ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh RUN bash ./install_mkl.sh && rm install_mkl.sh
@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
# Final step # Final step
FROM ${BASE_TARGET} as final FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda COPY --from=conda /opt/conda /opt/conda

View File

@ -63,7 +63,7 @@ docker build \
--target final \ --target final \
--progress plain \ --progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \ --build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=13" \ --build-arg "DEVTOOLSET_VERSION=11" \
${EXTRA_BUILD_ARGS} \ ${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \ -t ${tmp_tag} \
$@ \ $@ \

View File

@ -195,16 +195,13 @@ case "$tag" in
NINJA_VERSION=1.9.0 NINJA_VERSION=1.9.0
TRITON=yes TRITON=yes
;; ;;
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks) pytorch-linux-jammy-xpu-n-py3)
ANACONDA_PYTHON_VERSION=3.10 ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=11 GCC_VERSION=11
VISION=yes VISION=yes
XPU_VERSION=2025.2 XPU_VERSION=2025.2
NINJA_VERSION=1.9.0 NINJA_VERSION=1.9.0
TRITON=yes TRITON=yes
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
;; ;;
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10 ANACONDA_PYTHON_VERSION=3.10
@ -261,9 +258,9 @@ case "$tag" in
PYTHON_VERSION=3.10 PYTHON_VERSION=3.10
CUDA_VERSION=12.8.1 CUDA_VERSION=12.8.1
;; ;;
pytorch-linux-jammy-aarch64-py3.10-gcc13) pytorch-linux-jammy-aarch64-py3.10-gcc11)
ANACONDA_PYTHON_VERSION=3.10 ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13 GCC_VERSION=11
ACL=yes ACL=yes
VISION=yes VISION=yes
OPENBLAS=yes OPENBLAS=yes
@ -271,19 +268,9 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific # from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes SKIP_LLVM_SRC_BUILD_INSTALL=yes
;; ;;
pytorch-linux-jammy-aarch64-py3.10-clang21) pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10 ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21 GCC_VERSION=11
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
ACL=yes ACL=yes
VISION=yes VISION=yes
OPENBLAS=yes OPENBLAS=yes

View File

@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd

View File

@ -3,7 +3,7 @@
set -eux set -eux
ACL_VERSION=${ACL_VERSION:-"v52.6.0"} ACL_VERSION=${ACL_VERSION:-"v25.02"}
ACL_INSTALL_DIR="/acl" ACL_INSTALL_DIR="/acl"
# Clone ACL # Clone ACL

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts # work around ubuntu apt-get conflicts
sudo apt-get -y -f install sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION -ge 18 ]]; then if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main" apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
fi fi
fi fi

View File

@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
# Need the official toolchain repo to get alternate packages # Need the official toolchain repo to get alternate packages
add-apt-repository ppa:ubuntu-toolchain-r/test add-apt-repository ppa:ubuntu-toolchain-r/test
apt-get update apt-get update
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION apt-get install -y g++-$GCC_VERSION
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
# Cleanup package manager # Cleanup package manager
apt-get autoclean && apt-get clean apt-get autoclean && apt-get clean

View File

@ -1,56 +0,0 @@
#!/bin/bash
# Script used only in CD pipeline
set -ex
# install dependencies
dnf -y install gmp-devel libmpc-devel texinfo flex bison
cd /usr/local/src
# fetch source for gcc 13
git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
mkdir -p gcc-13.3.0/build-gomp
cd gcc-13.3.0/build-gomp
# configure gcc build
# I got these flags by:
# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
# dnf install -y dnf-plugins-core rpmdevtools
# dnf download --source libgomp
# 2. extracting the gcc.spec from the source.
# rpmdev-extract gcc-xx.src.rpm
# 3. extracting optflags and ld_flags from gcc.spec:
# rpm --eval '%{optflags}'
# rpm --eval '%{build_ldflags}'
#
# I had to remove the following flags because they didn't compile for this version of libgomp:
# -Werror=format-security
# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
#
# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
' -fstack-protector-strong -fasynchronous-unwind-tables'\
' -fstack-clash-protection'
LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
CFLAGS="$OPT_FLAGS" \
CXXFLAGS="$OPT_FLAGS" \
LDFLAGS="$LDFLAGS" \
../configure \
--prefix=/usr \
--libdir=/usr/lib64 \
--enable-languages=c,c++ \
--disable-multilib \
--disable-bootstrap \
--enable-libgomp
# only build libgomp
make -j$(nproc) all-target-libgomp
make install-target-libgomp

View File

@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS" OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS=" OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128 NUM_THREADS=128
USE_OPENMP=1 USE_OPENMP=1
NO_SHARED=0 NO_SHARED=0

View File

@ -12,8 +12,8 @@ function do_install() {
rocm_version_nodot=${rocm_version//./} rocm_version_nodot=${rocm_version//./}
# post merge of https://github.com/icl-utk-edu/magma/pull/65 # https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2" magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2"
rocm_dir="/opt/rocm" rocm_dir="/opt/rocm"

View File

@ -149,7 +149,7 @@ FROM cpu_final as rocm_final
ARG ROCM_VERSION=6.0 ARG ROCM_VERSION=6.0
ARG PYTORCH_ROCM_ARCH ARG PYTORCH_ROCM_ARCH
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ARG DEVTOOLSET_VERSION=13 ARG DEVTOOLSET_VERSION=11
ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib" ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib"
# Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path, # Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path,
# below workaround helps avoid error # below workaround helps avoid error

View File

@ -50,10 +50,6 @@ RUN rm install_ninja.sh
ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
# Build a newer version of libgomp than that supported in in Almalinux 8.
COPY ./common/install_libgomp.sh install_libgomp.sh
RUN bash ./install_libgomp.sh && rm install_libgomp.sh
# git236+ would refuse to run git commands in repos owned by other users # git236+ would refuse to run git commands in repos owned by other users
# Which causes version check to fail, as pytorch repo is bind-mounted into the image # Which causes version check to fail, as pytorch repo is bind-mounted into the image
# Override this behaviour by treating every folder as safe # Override this behaviour by treating every folder as safe

View File

@ -97,7 +97,7 @@ case ${image} in
manylinux2_28-builder:xpu) manylinux2_28-builder:xpu)
TARGET=xpu_final TARGET=xpu_final
GPU_IMAGE=amd64/almalinux:8 GPU_IMAGE=amd64/almalinux:8
DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13" DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=11"
MANY_LINUX_VERSION="2_28" MANY_LINUX_VERSION="2_28"
;; ;;
*) *)

View File

@ -1,11 +1,15 @@
sphinx==7.2.6 sphinx==5.3.0
#Description: This is used to generate PyTorch docs #Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6 #Pinned versions: 5.3.0
pytorch_sphinx_theme2==0.2.0 standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed to generate PyTorch docs #Description: This is needed by Sphinx, so it needs to be added here.
#Pinned versions: 0.2.0 # The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering # TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably # but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later. # something related to Docker setup. We can investigate this later.
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs #Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0 #Pinned versions: 2.13.0
breathe==4.36.0 breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs #Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0 #Pinned versions: 4.34.0
exhale==0.3.7 exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs #Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7 #Pinned versions: 0.2.3
docutils==0.20 docutils==0.16
#Description: This is used to generate PyTorch C++ docs #Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20 #Pinned versions: 0.16
bs4==0.0.1 bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs #Description: This is used to generate PyTorch C++ docs
@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs #Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0 #Pinned versions: 8.12.0
myst-nb==1.3.0 myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs. #Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0 #Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs # The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5 python-etcd==0.4.5
sphinx-copybutton==0.5.0 sphinx-copybutton==0.5.0
sphinx-design==0.6.1 sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0 sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1 myst-parser==0.18.1

View File

@ -1 +1 @@
3.5.1 3.5.0

View File

@ -54,15 +54,12 @@ ENV OPENSSL_DIR /opt/openssl
RUN rm install_openssl.sh RUN rm install_openssl.sh
ARG INDUCTOR_BENCHMARKS ARG INDUCTOR_BENCHMARKS
ARG ANACONDA_PYTHON_VERSION
ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION
COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh
COPY ./common/common_utils.sh common_utils.sh COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt
COPY ci_commit_pins/timm.txt timm.txt COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt
# Install XPU Dependencies # Install XPU Dependencies
ARG XPU_VERSION ARG XPU_VERSION

View File

@ -1,7 +1,7 @@
SHELL=/usr/bin/env bash SHELL=/usr/bin/env bash
DOCKER_CMD ?= docker DOCKER_CMD ?= docker
DESIRED_ROCM ?= 7.1 DESIRED_ROCM ?= 7.0
DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM)) DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM))
PACKAGE_NAME = magma-rocm PACKAGE_NAME = magma-rocm
# inherit this from underlying docker image, do not pass this env var to docker # inherit this from underlying docker image, do not pass this env var to docker
@ -16,7 +16,6 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \
magma-rocm/build_magma.sh magma-rocm/build_magma.sh
.PHONY: all .PHONY: all
all: magma-rocm71
all: magma-rocm70 all: magma-rocm70
all: magma-rocm64 all: magma-rocm64
@ -25,11 +24,6 @@ clean:
$(RM) -r magma-* $(RM) -r magma-*
$(RM) -r output $(RM) -r output
.PHONY: magma-rocm71
magma-rocm71: DESIRED_ROCM := 7.1
magma-rocm71:
$(DOCKER_RUN)
.PHONY: magma-rocm70 .PHONY: magma-rocm70
magma-rocm70: DESIRED_ROCM := 7.0 magma-rocm70: DESIRED_ROCM := 7.0
magma-rocm70: magma-rocm70:

View File

@ -426,7 +426,7 @@ fi
if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
# export test times so that potential sharded tests that'll branch off this build will use consistent data # export test times so that potential sharded tests that'll branch off this build will use consistent data
# don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
PYTHONPATH=. python tools/stats/export_test_times.py python tools/stats/export_test_times.py
fi fi
# don't do this for bazel or s390x or riscv64 as they don't use sccache # don't do this for bazel or s390x or riscv64 as they don't use sccache
if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then

View File

@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
make coverage make coverage
# Now we have the coverage report, we need to make sure it is empty. # Now we have the coverage report, we need to make sure it is empty.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row # Count the number of lines in the file and turn that number into a variable
# showing the undocumented count in the third column. # $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Example: | TOTAL | 99.83% | 2 | # Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
# #
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should # Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there. # be documented then removed from there.
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table undocumented=$((lines - 2))
# The table format is: | Module | Coverage | Undocumented | if [ $undocumented -lt 0 ]; then
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
echo coverage output not found echo coverage output not found
exit 1 exit 1
elif [ "$undocumented" -gt 0 ]; then elif [ $undocumented -gt 0 ]; then
set +x # Disable command echoing for cleaner output echo undocumented objects found:
echo "" cat build/coverage/python.txt
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
echo "Make sure you've updated relevant .rsts in docs/source!" echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'" echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
set -x # Re-enable command echoing
exit 1 exit 1
fi fi
else else

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() { test_python_smoke() {
# Smoke tests for H100/B200 # Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty assert_git_not_dirty
} }
@ -572,8 +572,6 @@ fi
if [[ "${TEST_CONFIG}" == *cpu* ]]; then if [[ "${TEST_CONFIG}" == *cpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device cpu) DYNAMO_BENCHMARK_FLAGS+=(--device cpu)
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
DYNAMO_BENCHMARK_FLAGS+=(--device xpu)
else else
DYNAMO_BENCHMARK_FLAGS+=(--device cuda) DYNAMO_BENCHMARK_FLAGS+=(--device cuda)
fi fi
@ -667,8 +665,6 @@ test_perf_for_dashboard() {
device=cuda_b200 device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm device=rocm
elif [[ "${TEST_CONFIG}" == *xpu* ]]; then
device=xpu
fi fi
for mode in "${modes[@]}"; do for mode in "${modes[@]}"; do
@ -1653,7 +1649,7 @@ test_operator_microbenchmark() {
cd "${TEST_DIR}"/benchmarks/operator_benchmark cd "${TEST_DIR}"/benchmarks/operator_benchmark
for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do
$TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \
--output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \
--benchmark-name "PyTorch operator microbenchmark" --use-compile --benchmark-name "PyTorch operator microbenchmark" --use-compile
@ -1761,7 +1757,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
else else
# Do this after checkout_install_torchbench to ensure we clobber any # Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in # nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then if [[ "${TEST_CONFIG}" != *cpu* ]]; then
install_torchrec_and_fbgemm install_torchrec_and_fbgemm
fi fi
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"

View File

@ -70,7 +70,7 @@ sccache --zero-stats
sccache --show-stats sccache --show-stats
# Build the wheel # Build the wheel
python -m build --wheel --no-isolation python -m build --wheel --no-build-isolation
if ($LASTEXITCODE -ne 0) { exit 1 } if ($LASTEXITCODE -ne 0) { exit 1 }
# Install the wheel locally # Install the wheel locally

View File

@ -60,11 +60,9 @@ performance-*,
readability-container-size-empty, readability-container-size-empty,
readability-delete-null-pointer, readability-delete-null-pointer,
readability-duplicate-include, readability-duplicate-include,
readability-named-parameter,
readability-misplaced-array-index, readability-misplaced-array-index,
readability-redundant*, readability-redundant*,
readability-simplify-subscript-expr, readability-simplify-subscript-expr,
readability-static-definition-in-anonymous-namespace
readability-string-compare, readability-string-compare,
-readability-redundant-access-specifiers, -readability-redundant-access-specifiers,
-readability-redundant-control-flow, -readability-redundant-control-flow,

View File

@ -1,319 +0,0 @@
---
name: add-uint-support
description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support.
---
# Add Unsigned Integer (uint) Support to Operators
This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros.
## When to use this skill
Use this skill when:
- Adding uint16, uint32, or uint64 support to an operator
- User mentions "unsigned types", "uint support", "barebones unsigned types"
- Enabling support for kUInt16, kUInt32, kUInt64 in kernels
- Working with operator implementations that need expanded type coverage
## Quick reference
**Add unsigned types to existing dispatch:**
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
// After (method 1: add unsigned types explicitly)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present)
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```
## Type group reference
**Unsigned type groups:**
- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64
- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES
**Relationship:**
```cpp
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES
```
## Instructions
### Step 1: Determine if conversion to V2 is needed
Check if the file uses AT_DISPATCH_V2:
**If using old AT_DISPATCH:**
- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill
- Then proceed with adding uint support
**If already using AT_DISPATCH_V2:**
- Proceed directly to Step 2
### Step 2: Analyze the current dispatch macro
Identify what type groups are currently in use:
```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
// body
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
^^^^^^^^^^^^^^^^^^^^^^^^^
Current type coverage
```
Common patterns:
- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES
- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only
- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types
### Step 3: Choose the uint addition method
Two approaches:
**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly**
- Use when: You want to be explicit about adding uint support
- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list
**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2**
- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)`
- More concise: replaces one type group with its superset
- Only applicable if AT_INTEGRAL_TYPES is present
### Step 4: Apply the transformation
**Method 1 example:**
```cpp
// Before
AT_DISPATCH_V2(
dtype,
"min_values_cuda",
AT_WRAP([&]() {
kernel_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);
// After (add unsigned types)
AT_DISPATCH_V2(
dtype,
"min_values_cuda",
AT_WRAP([&]() {
kernel_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
kBFloat16, kHalf, kBool
);
```
**Method 2 example:**
```cpp
// Before
AT_DISPATCH_V2(
dtype,
"integral_op",
AT_WRAP([&]() {
kernel<scalar_t>();
}),
AT_EXPAND(AT_INTEGRAL_TYPES)
);
// After (substitute with V2)
AT_DISPATCH_V2(
dtype,
"integral_op",
AT_WRAP([&]() {
kernel<scalar_t>();
}),
AT_EXPAND(AT_INTEGRAL_TYPES_V2)
);
```
### Step 5: Handle AT_ALL_TYPES vs individual type groups
If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`:
- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES`
- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list
If the dispatch separately lists INTEGRAL and FLOATING:
```cpp
// Before
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
// After (Method 2 preferred)
AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)
```
### Step 6: Verify all dispatch sites
Check the file for ALL dispatch macros that need uint support:
- Some operators have multiple dispatch sites (CPU, CUDA, different functions)
- Apply the transformation consistently across all sites
- Ensure each gets the same type coverage updates
### Step 7: Validate the changes
Check that:
- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH)
- [ ] Unsigned types are added via one of the two methods
- [ ] All relevant dispatch sites in the file are updated
- [ ] Type groups use `AT_EXPAND()`
- [ ] Arguments are properly formatted and comma-separated
## Common patterns
### Pattern 1: AT_ALL_TYPES + extras
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```
### Pattern 2: Separate INTEGRAL + FLOATING
```cpp
// Before
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES));
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES));
```
### Pattern 3: Old dispatch needs conversion first
```cpp
// Before (needs v2 conversion first)
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>();
});
// After v2 conversion
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
// After adding uint support
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16);
```
## Multiple dispatch sites example
For a file with multiple functions:
```cpp
void min_values_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() {
impl<scalar_t>(iter);
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
// Added uint support
}
void min_launch_kernel(TensorIterator &iter) {
AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t>(iter);
}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf);
// ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
// Added uint support here too
}
```
## Decision tree
Use this decision tree to determine the approach:
```
Is the file using AT_DISPATCH_V2?
├─ No → Use at-dispatch-v2 skill first, then continue
└─ Yes
└─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)?
├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2)
└─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list
```
## Edge cases
### Case 1: Dispatch with only floating types
If the operator only supports floating point types, don't add uint support:
```cpp
// Leave as-is - floating point only operator
AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf);
```
### Case 2: Complex types present
Unsigned types work alongside complex types:
```cpp
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kHalf, kBFloat16);
```
### Case 3: Already has uint support
Check if uint types are already present:
- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support
- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support
- Skip the file if uint support is already present
## Workflow
When asked to add uint support:
1. Read the target file
2. Check if using AT_DISPATCH_V2:
- If not → use at-dispatch-v2 skill first
3. Identify all dispatch macro sites
4. For each dispatch:
- Analyze current type groups
- Choose method (add BAREBONES_UNSIGNED or upgrade to V2)
- Apply transformation with Edit tool
5. Show the user the changes
6. Explain what was modified
## Important notes
- Always check if v2 conversion is needed first
- Apply changes consistently across all dispatch sites in the file
- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable
- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit
- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8)
- Some operators may not semantically support unsigned types - use judgment
## Testing
After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing.

View File

@ -1,305 +0,0 @@
---
name: at-dispatch-v2
description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations.
---
# AT_DISPATCH to AT_DISPATCH_V2 Converter
This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`.
## When to use this skill
Use this skill when:
- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2
- Porting ATen kernels to use the new dispatch API
- Working with files in `aten/src/ATen/native/` that use dispatch macros
- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion
## Quick reference
**Old format:**
```cpp
AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() {
// lambda body
});
```
**New format:**
```cpp
AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() {
// lambda body
}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool);
```
## Key transformations
1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types
2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas
3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion
4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups
5. **Add include**: `#include <ATen/Dispatch_v2.h>` near other Dispatch includes
## Instructions
### Step 1: Add the Dispatch_v2.h include
Add the v2 header near the existing `#include <ATen/Dispatch.h>`:
```cpp
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
```
Keep the old Dispatch.h include for now (other code may still need it).
### Step 2: Identify the old dispatch pattern
Common patterns to convert:
- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)`
- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)`
### Step 3: Map the old macro to type groups
Identify which type group macro corresponds to the base types:
| Old macro base | AT_DISPATCH_V2 type group |
|----------------|---------------------------|
| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` |
| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` |
| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` |
| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` |
| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` |
For combined patterns, use multiple `AT_EXPAND()` entries:
```cpp
// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...)
// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2
```
### Step 4: Extract the individual types
From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.).
These become the trailing arguments after the type group:
```cpp
AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool)
^^^^^^^^^^^^^^^^^^^^^^^^
Individual types from AND3
```
### Step 5: Transform to AT_DISPATCH_V2
Apply the transformation:
**Pattern:**
```cpp
AT_DISPATCH_V2(
scalar_type, // 1st: The dtype expression
"name", // 2nd: The debug string
AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP
type_groups, // 4th+: Type groups with AT_EXPAND()
individual_types // Last: Individual types
)
```
**Example transformation:**
```cpp
// BEFORE
AT_DISPATCH_ALL_TYPES_AND3(
kBFloat16, kHalf, kBool,
iter.dtype(),
"min_values_cuda",
[&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}
);
// AFTER
AT_DISPATCH_V2(
iter.dtype(),
"min_values_cuda",
AT_WRAP([&]() {
min_values_kernel_cuda_impl<scalar_t>(iter);
}),
AT_EXPAND(AT_ALL_TYPES),
kBFloat16, kHalf, kBool
);
```
### Step 6: Handle multi-line lambdas
For lambdas with internal commas or complex expressions, AT_WRAP is essential:
```cpp
AT_DISPATCH_V2(
dtype,
"complex_kernel",
AT_WRAP([&]() {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinOps<scalar_t>{},
thrust::pair<scalar_t, int64_t>(upper_bound(), 0) // Commas inside!
);
}),
AT_EXPAND(AT_ALL_TYPES)
);
```
### Step 7: Verify the conversion
Check that:
- [ ] `AT_WRAP()` wraps the entire lambda
- [ ] Type groups use `AT_EXPAND()`
- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`)
- [ ] Argument order is: scalar_type, name, lambda, types
- [ ] Include added: `#include <ATen/Dispatch_v2.h>`
## Type group reference
Available type group macros (use with `AT_EXPAND()`):
```cpp
AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort
AT_FLOATING_TYPES // kDouble, kFloat
AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat
AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32
AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES
AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES
AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types
AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64
AT_FLOAT8_TYPES // Float8 variants
```
## Common patterns
### Pattern: AT_DISPATCH_ALL_TYPES_AND2
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() {
kernel<scalar_t>(data);
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>(data);
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16);
```
### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3
```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn,
tensor.scalar_type(), "float_op", [&] {
process<scalar_t>(tensor);
});
// After
AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] {
process<scalar_t>(tensor);
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn);
```
### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kComplexHalf, kHalf,
self.scalar_type(),
"complex_op",
[&] {
result = compute<scalar_t>(self);
}
);
// After
AT_DISPATCH_V2(
self.scalar_type(),
"complex_op",
AT_WRAP([&] {
result = compute<scalar_t>(self);
}),
AT_EXPAND(AT_ALL_TYPES),
AT_EXPAND(AT_COMPLEX_TYPES),
kComplexHalf,
kHalf
);
```
## Edge cases
### Case 1: No extra types (rare)
```cpp
// Before
AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES));
```
### Case 2: Many individual types (AND4, AND5, etc.)
```cpp
// Before
AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2,
dtype, "float8_op", [&]() { kernel<scalar_t>(); });
// After
AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() {
kernel<scalar_t>();
}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2);
```
### Case 3: Lambda with no captures
```cpp
// Before
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() {
static_kernel<scalar_t>();
});
// After
AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() {
static_kernel<scalar_t>();
}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool);
```
## Benefits of AT_DISPATCH_V2
1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4
2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()`
3. **Extensible**: Easy to add more types without hitting macro limits
4. **Clearer**: Type groups are explicit, not implicit in macro name
## Important notes
- Keep `#include <ATen/Dispatch.h>` - other code may need it
- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda
- Type groups need `AT_EXPAND()`, individual types don't
- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs
- See the header file for the Python script to regenerate the macro implementation
## Workflow
When asked to convert AT_DISPATCH macros:
1. Read the file to identify all AT_DISPATCH uses
2. Add `#include <ATen/Dispatch_v2.h>` if not present
3. For each dispatch macro:
- Identify the pattern and extract components
- Map the base type group
- Extract individual types
- Construct the AT_DISPATCH_V2 call
- Apply with Edit tool
4. Show the user the complete converted file
5. Explain what was changed
Do NOT compile or test the code - focus on accurate conversion only.

View File

@ -1,11 +1,11 @@
name: 🚀 New Feature for Release name: 🚀 Release highlight for proposed Feature
description: Submit a Release highlight for proposed Feature description: Submit a Release highlight for proposed Feature
labels: ["release-feature-request"] labels: ["release-feature-request"]
body: body:
- type: textarea - type: textarea
attributes: attributes:
label: New Feature for Release label: Release highlight for proposed Feature
description: > description: >
Example: “A torch.special module, analogous to SciPy's special module.” Example: “A torch.special module, analogous to SciPy's special module.”
- type: input - type: input

View File

@ -27,9 +27,7 @@ runs:
docker system prune -af docker system prune -af
diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //') diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then
diskspace_cutoff_int=$((diskspace_cutoff + 0)) echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace."
difference=$((100 - diskspace_cutoff_int))
echo "Error: Available diskspace is less than $difference percent. Not enough diskspace."
echo "$msg" echo "$msg"
exit 1 exit 1
else else

View File

@ -38,9 +38,9 @@ runs:
run: | run: |
python3 .github/scripts/pytest_cache.py \ python3 .github/scripts/pytest_cache.py \
--download \ --download \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier "$GITHUB_REF" \ --pr_identifier $GITHUB_REF \
--job_identifier "$JOB_IDENTIFIER" \ --job_identifier $JOB_IDENTIFIER \
--temp_dir "$RUNNER_TEMP" \ --temp_dir $RUNNER_TEMP \
--repo "$REPO" \ --repo $REPO \
--bucket "$BUCKET" \ --bucket $BUCKET \

View File

@ -47,11 +47,11 @@ runs:
run: | run: |
python3 .github/scripts/pytest_cache.py \ python3 .github/scripts/pytest_cache.py \
--upload \ --upload \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier "$GITHUB_REF" \ --pr_identifier $GITHUB_REF \
--job_identifier "$JOB_IDENTIFIER" \ --job_identifier $JOB_IDENTIFIER \
--sha "$SHA" \ --sha $SHA \
--test_config "$TEST_CONFIG" \ --test_config $TEST_CONFIG \
--shard "$SHARD" \ --shard $SHARD \
--repo "$REPO" \ --repo $REPO \
--temp_dir "$RUNNER_TEMP" \ --temp_dir $RUNNER_TEMP \

View File

@ -1 +1 @@
ad5816f0eee1c873df1b7d371c69f1f811a89387 3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2

View File

@ -1 +1 @@
cfbc5c2f1c798991715a6b06bb3ce46478c4487c 218d2ab791d437309f91e0486eb9fa7f00badc17

View File

@ -1 +1 @@
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9 df6798dfb931ce7c7fe5bed2447cd1092a5981af

View File

@ -1,125 +0,0 @@
# PyTorch Copilot Instructions
This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
## Architecture Overview
### Core Components
- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
- `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
- `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
- **torch/** - Python bindings and public API
- `torch/csrc/` - C++ Python bindings (hand-written and generated)
- `torch/csrc/autograd/` - Reverse-mode automatic differentiation
- `torch/csrc/jit/` - TorchScript JIT compiler
- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
- **tools/** - Build scripts, autograd derivatives, code generation
### The Code Generation Workflow
**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
1. Declares operator signatures, variants (function/method), and dispatch behavior
2. Gets processed by `torchgen/` to generate C++/Python bindings
3. Produces headers in `build/aten/src/ATen/` during compilation
Example entry structure:
```yaml
- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
variants: function, method
dispatch:
CPU: my_op_cpu
CUDA: my_op_cuda
```
After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
## Development Workflows
### Building from Source
**Never run `setup.py` directly** - use pip with editable install:
```bash
python -m pip install --no-build-isolation -v -e .
```
Speed up builds:
- `DEBUG=1` - Debug symbols with `-g -O0`
- `USE_CUDA=0` - Skip CUDA compilation
- `BUILD_TEST=0` - Skip C++ test binaries
- Install `ninja` (`pip install ninja`) for faster builds
- Use `ccache` for incremental compilation caching
Rebuild specific targets: `(cd build && ninja <target>)`
### Testing
**Critical**: DO NOT run entire test suites. Run specific tests only:
```bash
python test/test_torch.py TestTorch.test_specific_case
```
**Test structure**: All tests use `torch.testing._internal.common_utils`:
```python
from torch.testing._internal.common_utils import run_tests, TestCase
class TestFeature(TestCase):
def test_something(self):
# Use self.assertEqual for tensor comparisons
pass
if __name__ == "__main__":
run_tests()
```
**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
### Linting
Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
## Project-Specific Conventions
### Memory and Storage
- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
- CUDA device info lives in storage objects
### Python-C++ Integration (`torch/csrc/`)
- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
### Dispatch System
- PyTorch uses operator dispatch to route calls to backend-specific kernels
- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
## Git Workflow (AI Agent Specific)
When preparing PRs from this environment:
```bash
git stash -u
git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch
git stash pop
# Resolve conflicts if necessary
```
## Common Gotchas
1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
4. **No internet access** - DO NOT attempt to install dependencies during development
## Key Files Reference
- `AGENTS.md` - Instructions specific to AI coding agents
- `CONTRIBUTING.md` - Comprehensive human contributor guide
- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
- `aten/src/ATen/native/README.md` - Operator implementation guide
- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
## Performance Debugging
Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.

View File

@ -19,7 +19,6 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-rocm-mi300 - ciflow/inductor-perf-test-nightly-rocm-mi300
- ciflow/inductor-perf-test-nightly-rocm-mi355 - ciflow/inductor-perf-test-nightly-rocm-mi355
- ciflow/inductor-perf-test-nightly-x86-zen - ciflow/inductor-perf-test-nightly-x86-zen
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic - ciflow/inductor-periodic
- ciflow/inductor-rocm - ciflow/inductor-rocm
- ciflow/linux-aarch64 - ciflow/linux-aarch64

View File

@ -11,24 +11,18 @@ architectures:
* Latest XPU * Latest XPU
""" """
import json
import os import os
import re
from pathlib import Path
from typing import Optional from typing import Optional
SCRIPT_DIR = Path(__file__).absolute().parent # NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this
REPO_ROOT = SCRIPT_DIR.parent.parent
CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"] CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"]
CUDA_STABLE = "12.8" CUDA_STABLE = "12.8"
CUDA_ARCHES_FULL_VERSION = { CUDA_ARCHES_FULL_VERSION = {
"12.6": "12.6.3", "12.6": "12.6.3",
"12.8": "12.8.1", "12.8": "12.8.1",
"12.9": "12.9.1", "12.9": "12.9.1",
"13.0": "13.0.0", "13.0": "13.0.2",
} }
CUDA_ARCHES_CUDNN_VERSION = { CUDA_ARCHES_CUDNN_VERSION = {
"12.6": "9", "12.6": "9",
@ -37,7 +31,8 @@ CUDA_ARCHES_CUDNN_VERSION = {
"13.0": "9", "13.0": "9",
} }
ROCM_ARCHES = ["7.0", "7.1"] # NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this
ROCM_ARCHES = ["6.4", "7.0"]
XPU_ARCHES = ["xpu"] XPU_ARCHES = ["xpu"]
@ -142,48 +137,9 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
} }
# Used by tools/nightly.py
PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly"
NIGHTLY_SOURCE_MATRIX = {
"cpu": dict(
name="cpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu",
supported_platforms=["Linux", "macOS", "Windows"],
accelerator="cpu",
)
}
CUDA_NIGHTLY_SOURCE_MATRIX = {
f"cuda-{major}.{minor}": dict(
name=f"cuda-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}",
supported_platforms=["Linux", "Windows"],
accelerator="cuda",
)
for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES)
}
ROCM_NIGHTLY_SOURCE_MATRIX = {
f"rocm-{major}.{minor}": dict(
name=f"rocm-{major}.{minor}",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}",
supported_platforms=["Linux"],
accelerator="rocm",
)
for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES)
}
XPU_NIGHTLY_SOURCE_MATRIX = {
"xpu": dict(
name="xpu",
index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu",
supported_platforms=["Linux"],
accelerator="xpu",
)
}
NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX)
NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX)
def get_nccl_wheel_version(arch_version: str) -> str: def get_nccl_wheel_version(arch_version: str) -> str:
import re
requirements = map( requirements = map(
str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]) str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version])
) )
@ -191,14 +147,17 @@ def get_nccl_wheel_version(arch_version: str) -> str:
def read_nccl_pin(arch_version: str) -> str: def read_nccl_pin(arch_version: str) -> str:
nccl_pin_path = ( from pathlib import Path
REPO_ROOT
/ ".ci" nccl_pin_path = os.path.join(
/ "docker" Path(__file__).absolute().parents[2],
/ "ci_commit_pins" ".ci",
/ f"nccl-cu{arch_version[:2]}.txt" "docker",
"ci_commit_pins",
f"nccl-cu{arch_version[:2]}.txt",
) )
return nccl_pin_path.read_text().strip() with open(nccl_pin_path) as f:
return f.read().strip()
def validate_nccl_dep_consistency(arch_version: str) -> None: def validate_nccl_dep_consistency(arch_version: str) -> None:
@ -206,8 +165,7 @@ def validate_nccl_dep_consistency(arch_version: str) -> None:
wheel_ver = get_nccl_wheel_version(arch_version) wheel_ver = get_nccl_wheel_version(arch_version)
if not nccl_release_tag.startswith(f"v{wheel_ver}"): if not nccl_release_tag.startswith(f"v{wheel_ver}"):
raise RuntimeError( raise RuntimeError(
f"{arch_version} NCCL release tag version {nccl_release_tag} " f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}"
f"does not correspond to wheel version {wheel_ver}"
) )
@ -454,14 +412,7 @@ def generate_wheels_matrix(
return ret return ret
arch_version = "" validate_nccl_dep_consistency("13.0")
for arch_version in CUDA_ARCHES: validate_nccl_dep_consistency("12.9")
validate_nccl_dep_consistency(arch_version) validate_nccl_dep_consistency("12.8")
del arch_version validate_nccl_dep_consistency("12.6")
if __name__ == "__main__":
# Used by tools/nightly.py
(SCRIPT_DIR / "nightly_source_matrix.json").write_text(
json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n"
)

View File

@ -97,8 +97,8 @@ jobs:
shell: bash shell: bash
run: | run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus. if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs" echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
exit 1 exit 1
fi fi

View File

@ -38,10 +38,6 @@ on:
default: "" default: ""
description: | description: |
List of tests to include (empty string implies default list) List of tests to include (empty string implies default list)
dashboard-tag:
required: false
type: string
default: ""
disable-monitor: disable-monitor:
description: | description: |
[Experimental] Disable utilization monitoring for tests. [Experimental] Disable utilization monitoring for tests.
@ -62,11 +58,6 @@ on:
required: false required: false
type: number type: number
default: 1 default: 1
secrets:
HUGGING_FACE_HUB_TOKEN:
required: false
description: |
HF Auth token to avoid rate limits when downloading models or datasets from hub
permissions: permissions:
id-token: write id-token: write
contents: read contents: read
@ -205,8 +196,6 @@ jobs:
PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }}
PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }}
TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }} TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }}
DASHBOARD_TAG: ${{ inputs.dashboard-tag }}
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }} timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }}
run: | run: |
# Fetch aws credential from IMDs # Fetch aws credential from IMDs
@ -257,8 +246,6 @@ jobs:
-e PYTORCH_TEST_RERUN_DISABLED_TESTS \ -e PYTORCH_TEST_RERUN_DISABLED_TESTS \
-e TESTS_TO_INCLUDE \ -e TESTS_TO_INCLUDE \
-e ZE_AFFINITY_MASK \ -e ZE_AFFINITY_MASK \
-e HUGGING_FACE_HUB_TOKEN \
-e DASHBOARD_TAG \
--env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \
--ulimit stack=10485760:83886080 \ --ulimit stack=10485760:83886080 \
--ulimit core=0 \ --ulimit core=0 \
@ -344,21 +331,5 @@ jobs:
if-no-files-found: ignore if-no-files-found: ignore
path: ./**/core.[1-9]* path: ./**/core.[1-9]*
- name: Authenticate with AWS
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
# The max duration enforced by the server side
role-duration-seconds: 18000
aws-region: us-east-1
- name: Upload the benchmark results
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
with:
benchmark-results-dir: test/test-reports
dry-run: false
schema-version: v3
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Teardown XPU - name: Teardown XPU
uses: ./.github/actions/teardown-xpu uses: ./.github/actions/teardown-xpu

View File

@ -36,7 +36,7 @@ jobs:
runs-on: linux.9xlarge.ephemeral runs-on: linux.9xlarge.ephemeral
strategy: strategy:
matrix: matrix:
tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"] tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"]
steps: steps:
- name: Build docker image - name: Build docker image
uses: pytorch/pytorch/.github/actions/binary-docker-build@main uses: pytorch/pytorch/.github/actions/binary-docker-build@main

View File

@ -52,8 +52,8 @@ jobs:
{ tag: "cuda12.9" }, { tag: "cuda12.9" },
{ tag: "cuda12.8" }, { tag: "cuda12.8" },
{ tag: "cuda12.6" }, { tag: "cuda12.6" },
{ tag: "rocm6.4" },
{ tag: "rocm7.0" }, { tag: "rocm7.0" },
{ tag: "rocm7.1" },
{ tag: "cpu" }, { tag: "cpu" },
] ]
steps: steps:

View File

@ -34,7 +34,7 @@ jobs:
id-token: write id-token: write
strategy: strategy:
matrix: matrix:
rocm_version: ["71", "70"] rocm_version: ["70", "64"]
steps: steps:
- name: Checkout PyTorch - name: Checkout PyTorch
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -54,8 +54,8 @@ jobs:
{ name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" },
{ name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" },
{ name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" },

View File

@ -55,7 +55,7 @@ jobs:
docker-image: ["pytorch/manylinux2_28-builder:cpu"] docker-image: ["pytorch/manylinux2_28-builder:cpu"]
include: include:
- device: "rocm" - device: "rocm"
rocm_version: "7.1" rocm_version: "7.0"
runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
- device: "cuda" - device: "cuda"
rocm_version: "" rocm_version: ""
@ -159,7 +159,12 @@ jobs:
WITH_CLANG_LDD="--with-clang-ldd" WITH_CLANG_LDD="--with-clang-ldd"
fi fi
docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD" if [[ "${BUILD_DEVICE}" == xpu ]]; then
docker exec -t "${container_name}" bash -c "dnf install -y gcc-toolset-13-gcc-c++"
docker exec -t "${container_name}" bash -c "source /opt/rh/gcc-toolset-13/enable && ${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE"
else
docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD"
fi
if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then
docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl" docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl"

View File

@ -67,7 +67,6 @@ jobs:
pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-xpu-n-1-py3, pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-jammy-xpu-n-py3, pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan, pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter, pytorch-linux-jammy-linter,
@ -77,11 +76,9 @@ jobs:
pytorch-linux-noble-riscv64-py3.12-gcc14 pytorch-linux-noble-riscv64-py3.12-gcc14
] ]
include: include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13 - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600 timeout-minutes: 600
# Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358 # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358

View File

@ -8,7 +8,6 @@ on:
- docker.Makefile - docker.Makefile
- .github/workflows/docker-release.yml - .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py - .github/scripts/generate_docker_release_matrix.py
- .github/scripts/generate_binary_build_matrix.py
push: push:
branches: branches:
- nightly - nightly

View File

@ -384,6 +384,124 @@ jobs:
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm6_4-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm6_4-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm6_4-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm6_4-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm6_4-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm6.4
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm6_4-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm6_4-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.4
GPU_ARCH_VERSION: "6.4"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm6_4-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_0-shared-with-deps-release-build: libtorch-rocm7_0-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }} if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml uses: ./.github/workflows/_binary-build-linux.yml
@ -501,121 +619,3 @@ jobs:
secrets: secrets:
github-token: ${{ secrets.GITHUB_TOKEN }} github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm7_1-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
timeout-minutes: 300
build_name: libtorch-rocm7_1-shared-with-deps-release
build_environment: linux-binary-libtorch
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm7_1-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- libtorch-rocm7_1-shared-with-deps-release-build
- get-label-type
runs-on: linux.rocm.gpu.mi250
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
permissions:
id-token: write
contents: read
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v4.1.7
name: Download Build Artifacts
with:
name: libtorch-rocm7_1-shared-with-deps-release
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: actions/checkout@v4
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
show-progress: false
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: configure aws credentials
id: aws_creds
if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }}
uses: aws-actions/configure-aws-credentials@v4
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }}
docker-image-name: libtorch-cxx11-builder
custom-tag-prefix: rocm7.1
docker-build-dir: .ci/docker
working-directory: pytorch
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
env:
DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm7_1-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm7_1-shared-with-deps-release-test
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm7.1
GPU_ARCH_VERSION: "7.1"
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: libtorch-cxx11-builder
DOCKER_IMAGE_TAG_PREFIX: rocm7.1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
build_name: libtorch-rocm7_1-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
uses: ./.github/workflows/_binary-upload.yml

File diff suppressed because it is too large Load Diff

View File

@ -72,7 +72,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.arm64.m7g.4xlarge runner: linux.arm64.m7g.4xlarge
build-environment: linux-jammy-aarch64-py3.10 build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" }, { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },

View File

@ -1,148 +0,0 @@
name: inductor-perf-nightly-xpu
on:
push:
tags:
- ciflow/inductor-perf-test-nightly-xpu/*
schedule:
- cron: 30 17 * * *
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: false
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
xpu-n-py3_10-inductor-benchmark-build:
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" },
{ config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" },
]}
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test-nightly:
permissions:
id-token: write
contents: read
if: github.event_name != 'workflow_dispatch'
name: xpu-n-py3.10-inductor-benchmark
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
# Disable monitor in perf tests for more investigation
disable-monitor: true
monitor-log-interval: 10
monitor-data-collect-interval: 2
secrets: inherit
xpu-n-py3_10-inductor-benchmark-test:
permissions:
id-token: write
contents: read
if: github.event_name == 'workflow_dispatch'
name: xpu-n-py3.10-inductor-test
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -1,10 +1,9 @@
name: inductor-rocm name: inductor-rocm
on: on:
schedule:
- cron: 0 */3 * * *
push: push:
branches: branches:
- main
- release/* - release/*
tags: tags:
- ciflow/inductor-rocm/* - ciflow/inductor-rocm/*

View File

@ -115,10 +115,10 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
{ config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" },
]} ]}
secrets: inherit secrets: inherit

View File

@ -84,13 +84,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" },
{ config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" },
]} ]}
build-additional-packages: "vision audio torchao" build-additional-packages: "vision audio torchao"

View File

@ -76,12 +76,11 @@ jobs:
# NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes
# fails to find types when it should # fails to find types when it should
# NOTE: We should be able to disable this and consolidate with Pyrefly lintrunner-mypy:
lintrunner-pyrefly:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }}
needs: [get-label-type, get-changed-files] needs: [get-label-type, get-changed-files]
# Only run if there are changed files relevant to pyrefly # Only run if there are changed files relevant to mypy
if: | if: |
github.repository_owner == 'pytorch' && ( github.repository_owner == 'pytorch' && (
needs.get-changed-files.outputs.changed-files == '*' || needs.get-changed-files.outputs.changed-files == '*' ||
@ -99,8 +98,8 @@ jobs:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
script: | script: |
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running pyrefly" echo "Running mypy"
ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh
lintrunner-noclang: lintrunner-noclang:
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
@ -119,9 +118,9 @@ jobs:
CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}"
echo "Running all other linters" echo "Running all other linters"
if [ "$CHANGED_FILES" = '*' ]; then if [ "$CHANGED_FILES" = '*' ]; then
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh
else else
ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh
fi fi
quick-checks: quick-checks:

View File

@ -33,7 +33,7 @@ jobs:
with: with:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }} runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-aarch64-py3.10 build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge runner: linux.arm64.m7g.4xlarge
test-matrix: | test-matrix: |
{ include: [ { include: [

View File

@ -41,7 +41,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml uses: ./.github/workflows/_linux-build.yml
needs: get-label-type needs: get-label-type
with: with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge"
build-environment: linux-jammy-py3.10-gcc11 build-environment: linux-jammy-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11
secrets: inherit secrets: inherit

View File

@ -60,7 +60,7 @@ jobs:
with: with:
build-environment: linux-jammy-aarch64-py3.10 build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },

View File

@ -66,10 +66,10 @@ jobs:
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]} ]}
secrets: inherit secrets: inherit
@ -167,8 +167,8 @@ jobs:
docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx
test-matrix: | test-matrix: |
{ include: [ { include: [
{ config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
]} ]}
secrets: inherit secrets: inherit

View File

@ -3,14 +3,13 @@ name: rocm
on: on:
push: push:
branches: branches:
- main
- release/* - release/*
tags: tags:
- ciflow/rocm/* - ciflow/rocm/*
workflow_dispatch: workflow_dispatch:
schedule: schedule:
- cron: 29 8 * * * # about 1:29am PDT - cron: 29 8 * * * # about 1:29am PDT
- cron: 0 */3 * * *
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}

View File

@ -204,7 +204,6 @@ jobs:
{ include: [ { include: [
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" }, { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.4" },
]} ]}
secrets: inherit secrets: inherit
@ -222,7 +221,7 @@ jobs:
build-environment: linux-jammy-rocm-py3.10 build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"
secrets: inherit secrets: inherit
inductor-build: inductor-build:

3
.gitignore vendored
View File

@ -127,7 +127,6 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
@ -144,7 +143,6 @@ scripts/release_notes/*.json
sccache-stats*.json sccache-stats*.json
lint.json lint.json
merge_record.json merge_record.json
.github/scripts/nightly_source_matrix.json
# These files get copied over on invoking setup.py # These files get copied over on invoking setup.py
torchgen/packaged/* torchgen/packaged/*
@ -399,4 +397,3 @@ CLAUDE.local.md
/test_*.py /test_*.py
/debug_*.py /debug_*.py
CLAUDE_CONTEXT/ CLAUDE_CONTEXT/
/.claude/settings.local.json

View File

@ -121,6 +121,94 @@ command = [
] ]
is_formatter = true is_formatter = true
[[linter]]
code = 'MYPY'
include_patterns = [
'setup.py',
'functorch/dim/**/*.py',
'torch/**/*.py',
'torch/**/*.pyi',
'caffe2/**/*.py',
'caffe2/**/*.pyi',
'test/test_bundled_images.py',
'test/test_bundled_inputs.py',
'test/test_complex.py',
'test/test_datapipe.py',
'test/test_futures.py',
'test/test_numpy_interop.py',
'test/test_torch.py',
'test/test_type_hints.py',
'test/test_type_info.py',
'test/test_utils.py',
]
exclude_patterns = [
'**/fb/**',
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy.ini',
'--',
'@{{PATHSFILE}}'
]
init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'mypy==1.16.0',
'sympy==1.13.3',
'types-requests==2.27.25',
'types-pyyaml==6.0.2',
'types-tabulate==0.8.8',
'types-protobuf==5.29.1.20250403',
'types-setuptools==79.0.0.20250422',
'types-jinja2==2.11.9',
'types-colorama==0.4.6',
'filelock==3.18.0',
'junitparser==2.1.1',
'rich==14.1.0',
'pyyaml==6.0.2',
'optree==0.13.0',
'dataclasses-json==0.6.7',
'pandas==2.2.3',
]
[[linter]]
code = 'MYPYSTRICT'
include_patterns = [
'.github/**/*.py',
'benchmarks/instruction_counts/**/*.py',
'tools/**/*.py',
'torchgen/**/*.py',
'torch/utils/_pytree.py',
'torch/utils/_cxx_pytree.py',
'torch/utils/benchmark/utils/common.py',
'torch/utils/benchmark/utils/timer.py',
'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py',
]
exclude_patterns = [
# (linbinyu) copied from internal repo
'**/fb/**',
'tools/code_analyzer/gen_operators_yaml.py',
'tools/dynamo/verify_dynamo.py',
'tools/gen_vulkan_spv.py',
'tools/test/gen_operators_yaml_test.py',
'tools/test/gen_oplist_test.py',
'tools/test/test_selective_build.py',
'tools/experimental/torchfuzz/**',
]
command = [
'python3',
'tools/linter/adapters/mypy_linter.py',
'--config=mypy-strict.ini',
'--code=MYPYSTRICT',
'--',
'@{{PATHSFILE}}'
]
[[linter]] [[linter]]
code = 'PYREFLY' code = 'PYREFLY'
@ -142,9 +230,7 @@ init_command = [
'python3', 'python3',
'tools/linter/adapters/pip_init.py', 'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}', '--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"', 'numpy==2.1.0 ; python_version >= "3.12"',
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
'numpy==2.3.4 ; python_version >= "3.14"',
'expecttest==0.3.0', 'expecttest==0.3.0',
'pyrefly==0.36.2', 'pyrefly==0.36.2',
'sympy==1.13.3', 'sympy==1.13.3',
@ -212,6 +298,7 @@ exclude_patterns = [
'**/*pb.h', '**/*pb.h',
'**/*inl.h', '**/*inl.h',
'aten/src/ATen/cpu/FlushDenormal.cpp', 'aten/src/ATen/cpu/FlushDenormal.cpp',
'aten/src/ATen/cpu/Utils.cpp',
'aten/src/ATen/cpu/vml.h', 'aten/src/ATen/cpu/vml.h',
'aten/src/ATen/CPUFixedAllocator.h', 'aten/src/ATen/CPUFixedAllocator.h',
'aten/src/ATen/Parallel*.h', 'aten/src/ATen/Parallel*.h',
@ -230,6 +317,8 @@ exclude_patterns = [
'c10/util/win32-headers.h', 'c10/util/win32-headers.h',
'c10/test/**/*.h', 'c10/test/**/*.h',
'third_party/**/*', 'third_party/**/*',
'torch/csrc/api/include/torch/nn/modules/common.h',
'torch/csrc/api/include/torch/linalg.h',
'torch/csrc/autograd/generated/**', 'torch/csrc/autograd/generated/**',
'torch/csrc/distributed/**/*.cu', 'torch/csrc/distributed/**/*.cu',
'torch/csrc/distributed/c10d/WinSockUtils.hpp', 'torch/csrc/distributed/c10d/WinSockUtils.hpp',
@ -241,6 +330,7 @@ exclude_patterns = [
'torch/csrc/utils/generated_serialization_types.h', 'torch/csrc/utils/generated_serialization_types.h',
'torch/csrc/utils/pythoncapi_compat.h', 'torch/csrc/utils/pythoncapi_compat.h',
'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h', 'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h',
'aten/src/ATen/ExpandBase.h',
] ]
init_command = [ init_command = [
'python3', 'python3',

View File

@ -234,17 +234,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
option(USE_LSAN "Use Leak Sanitizer" OFF) option(USE_LSAN "Use Leak Sanitizer" OFF)
option(USE_TSAN "Use Thread Sanitizer" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF)
# Track whether USE_CUDA was explicitly set by the user (before option() is called)
# If USE_CUDA is already defined in cache, it means user explicitly set it
if(DEFINED CACHE{USE_CUDA})
set(_USE_CUDA_EXPLICITLY_SET TRUE)
else()
set(_USE_CUDA_EXPLICITLY_SET FALSE)
endif()
option(USE_CUDA "Use CUDA" ON) option(USE_CUDA "Use CUDA" ON)
option(USE_XPU "Use XPU" ON) option(USE_XPU "Use XPU" ON)
cmake_dependent_option( cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON

View File

@ -11,6 +11,7 @@ aspects of contributing to PyTorch.
<!-- toc --> <!-- toc -->
- [Developing PyTorch](#developing-pytorch) - [Developing PyTorch](#developing-pytorch)
- [Setup the development environment](#setup-the-development-environment)
- [Tips and Debugging](#tips-and-debugging) - [Tips and Debugging](#tips-and-debugging)
- [Nightly Checkout & Pull](#nightly-checkout--pull) - [Nightly Checkout & Pull](#nightly-checkout--pull)
- [Codebase structure](#codebase-structure) - [Codebase structure](#codebase-structure)
@ -18,7 +19,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing) - [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting) - [Local linting](#local-linting)
- [Running `pyrefly`](#running-pyrefly) - [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing) - [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs) - [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change) - [Merging your Change](#merging-your-change)
@ -66,6 +67,23 @@ aspects of contributing to PyTorch.
Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions. Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions.
### Setup the development environment
First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials.
Then clone the PyTorch project and setup the development environment:
```bash
git clone git@github.com:<USERNAME>/pytorch.git
cd pytorch
git remote add upstream git@github.com:pytorch/pytorch.git
make setup-env
# Or run `make setup-env-cuda` for pre-built CUDA binaries
# Or run `make setup-env-rocm` for pre-built ROCm binaries
source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows
```
### Tips and Debugging ### Tips and Debugging
* If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below. * If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below.
@ -281,7 +299,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**: **Prerequisites**:
The following packages should be installed with `pip`: The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests - `expecttest` and `hypothesis` - required to run tests
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/) - `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively - `pytest` - recommended to run tests more selectively
Running Running
``` ```
@ -350,32 +368,15 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner) Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `pyrefly` #### Running `mypy`
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback. `mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
See [Guide for adding type annotations to See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch) PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase. for more information on how to set up `mypy` and tackle type annotation
tasks.
### C++ Unit Testing ### C++ Unit Testing

View File

@ -1,7 +1,7 @@
# Security Policy # Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability) - [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely) - [**Using Pytorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models) - [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models) - [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs) - [Untrusted inputs](#untrusted-inputs)
@ -10,28 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles) - [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues ## Reporting Security Issues
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch. Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported: Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat https://www.facebook.com/whitehat
## Using PyTorch Securely ## Using Pytorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). **Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models ### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources]. Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. **Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models ### TorchScript models
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction ### Untrusted inputs during training and prediction
@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy ### Data privacy
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: **Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment) - Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits). - If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
### Using distributed features ### Using distributed features

View File

@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI)
if(USE_CUDA) if(USE_CUDA)
# To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build. # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build.
# If you want to integrate a kernel from FBGEMM into torch, you have to add it here. # If you want to integrate a kernel from FBGEMM into torch, you have to add it here.
set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*") set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*")
file(GLOB_RECURSE fbgemm_genai_native_cuda_cu file(GLOB_RECURSE fbgemm_genai_native_cuda_cu
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu" "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu"
"${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu") "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu")

View File

@ -181,7 +181,7 @@ c10::intrusive_ptr<c10::TensorImpl> CPUGeneratorImpl::get_state() const {
static const size_t size = sizeof(CPUGeneratorImplState); static const size_t size = sizeof(CPUGeneratorImplState);
static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType"); static_assert(std::is_standard_layout_v<CPUGeneratorImplState>, "CPUGeneratorImplState is not a PODType");
auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr(); auto rng_state = state_tensor.data_ptr();
// accumulate generator data to be copied into byte tensor // accumulate generator data to be copied into byte tensor

View File

@ -23,6 +23,8 @@ C10_DIAGNOSTIC_POP()
#endif #endif
namespace at { namespace at {
namespace {
/* /*
These const variables defined the fp32 precisions for different backend These const variables defined the fp32 precisions for different backend
We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 We have "generic", "cuda", "mkldnn" backend now and we can choose fp32
@ -39,6 +41,16 @@ namespace at {
->rnn ->rnn
*/ */
C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){
TORCH_WARN_ONCE(
"Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' "
"or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, "
"torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see "
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"
);
}
} // namespace
Float32Backend str2backend(const std::string& name) { Float32Backend str2backend(const std::string& name) {
if (name == "generic") if (name == "generic")
return Float32Backend::GENERIC; return Float32Backend::GENERIC;
@ -194,6 +206,7 @@ bool Context::allowTF32CuDNN(std::optional<Float32Op> op) const {
} else { } else {
return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32; return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32;
} }
warn_deprecated_fp32_precision_api();
return allow_tf32_cudnn; return allow_tf32_cudnn;
} }
@ -201,6 +214,7 @@ void Context::setAllowTF32CuDNN(bool b) {
setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE); setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE);
setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE); setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE);
allow_tf32_cudnn = b; allow_tf32_cudnn = b;
warn_deprecated_fp32_precision_api();
} }
void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) { void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
@ -209,7 +223,7 @@ void Context::setSDPPriorityOrder(const std::vector<int64_t>& order) {
"setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ", "setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ",
at::num_sdp_backends, " unique backends specified in priority order."); at::num_sdp_backends, " unique backends specified in priority order.");
for (uint32_t i = 0; i < order.size(); i++) { for (uint32_t i = 0; i < order.size(); i++) {
sdp_priority_order[i] = static_cast<at::SDPBackend>(order[i]); sdp_priority_order[i] = (at::SDPBackend) order[i];
} }
} }
@ -311,6 +325,7 @@ bool Context::allowTF32CuBLAS() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ", "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ",
"We suggest only using the new API to set the TF32 flag. See also: ", "We suggest only using the new API to set the TF32 flag. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return allow_tf32_new; return allow_tf32_new;
} }
@ -334,6 +349,7 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const {
"Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ", "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ",
"We suggest only using the new API for matmul precision. See also: ", "We suggest only using the new API for matmul precision. See also: ",
"https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices");
warn_deprecated_fp32_precision_api();
return float32_matmul_precision; return float32_matmul_precision;
} }
@ -361,6 +377,7 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op)
void Context::setFloat32MatmulPrecision(const std::string &s) { void Context::setFloat32MatmulPrecision(const std::string &s) {
auto match = [this](const std::string & s_) { auto match = [this](const std::string & s_) {
warn_deprecated_fp32_precision_api();
// TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
if (s_ == "highest") { if (s_ == "highest") {
float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
@ -808,14 +825,6 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) {
display_vmap_fallback_warnings_ = enabled; display_vmap_fallback_warnings_ = enabled;
} }
bool Context::warnOnAccumulateGradStreamMismatch() const {
return warn_on_accumulate_grad_stream_mismatch_;
}
void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) {
warn_on_accumulate_grad_stream_mismatch_ = enabled;
}
bool Context::isDefaultMobileCPUAllocatorSet() { bool Context::isDefaultMobileCPUAllocatorSet() {
return prev_allocator_ptr_ != nullptr; return prev_allocator_ptr_ != nullptr;
} }

View File

@ -174,12 +174,6 @@ class TORCH_API Context {
static long versionCuDNN() { static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN(); return detail::getCUDAHooks().versionCuDNN();
} }
static long versionRuntimeCuDNN() {
return detail::getCUDAHooks().versionRuntimeCuDNN();
}
static long versionCuDNNFrontend() {
return detail::getCUDAHooks().versionCuDNNFrontend();
}
static bool hasCuSOLVER() { static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER(); return detail::getCUDAHooks().hasCuSOLVER();
} }
@ -410,9 +404,6 @@ class TORCH_API Context {
void setDisplayVmapFallbackWarnings(bool enabled); void setDisplayVmapFallbackWarnings(bool enabled);
bool areVmapFallbackWarningsEnabled() const; bool areVmapFallbackWarningsEnabled() const;
void setWarnOnAccumulateGradStreamMismatch(bool enabled);
bool warnOnAccumulateGradStreamMismatch() const;
bool isDefaultMobileCPUAllocatorSet(); bool isDefaultMobileCPUAllocatorSet();
void setDefaultMobileCPUAllocator(); void setDefaultMobileCPUAllocator();
void unsetDefaultMobileCPUAllocator(); void unsetDefaultMobileCPUAllocator();
@ -503,7 +494,6 @@ class TORCH_API Context {
bool release_original_weights = false; bool release_original_weights = false;
#endif #endif
bool display_vmap_fallback_warnings_ = false; bool display_vmap_fallback_warnings_ = false;
bool warn_on_accumulate_grad_stream_mismatch_ = true;
std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine; std::atomic<at::QEngine> quantized_engine = at::QEngine::NoQEngine;
bool enable_sparse_tensor_invariant_checks = false; bool enable_sparse_tensor_invariant_checks = false;
bool allow_fp16_reduction_cpu = false; bool allow_fp16_reduction_cpu = false;

View File

@ -6,7 +6,6 @@
#include <c10/util/Half.h> #include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h> #include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h> #include <c10/util/complex.h>
#include <torch/headeronly/core/Dispatch.h>
#ifdef __CUDACC__ #ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION #include <cuda.h> // For CUDA_VERSION
@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
} \ } \
} while (0) } while (0)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ #define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \ case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__) AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
#define AT_DISPATCH_CASE(enum_type, ...) \ #define AT_DISPATCH_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
return __VA_ARGS__(); \ return __VA_ARGS__(); \
} }
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
} // namespace detail
// The AT_DISPATCH_* family of macros provides the ability to // The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the // conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because // dtypes we care about in PyTorch. We call it "dispatch" because
@ -180,13 +190,25 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// but we're just being safe (and it doesn't hurt.) Note we must // but we're just being safe (and it doesn't hurt.) Note we must
// use it to shut up warnings about unused store. // use it to shut up warnings about unused store.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ #define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH_TMPL( \ [&] { \
RECORD_KERNEL_FUNCTION_DTYPE, \ const auto& the_type = TYPE; \
TORCH_CHECK_NOT_IMPLEMENTED, \ constexpr const char* at_dispatch_name = NAME; \
TYPE, \ /* don't use TYPE again in case it is an expensive or side-effect op */ \
NAME, \ at::ScalarType _st = ::detail::scalar_type(the_type); \
__VA_ARGS__) RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK_NOT_IMPLEMENTED( \
false, \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
toString(_st), \
"'"); \
} \
}()
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \

View File

@ -1,8 +1,3 @@
#pragma once
#include <torch/headeronly/core/Dispatch_v2.h>
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
#include <ATen/Dispatch.h> #include <ATen/Dispatch.h>
// This is a new implementation of the AT_DISPATCH macro family from // This is a new implementation of the AT_DISPATCH macro family from
@ -79,19 +74,41 @@
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly // macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right. // relied on GPT4 to help me get it right.
// Public API macros
// See documentation above // See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \ #define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \ AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
AT_DISPATCH_SWITCH, \
AT_DISPATCH_CASE, \ // This macro lets you pass an arbitrary expression that may contain internal
TYPE, \ // commas to another macro without having the commas causing the expression
NAME, \ // to be interpreted as being multiple arguments
AT_WRAP(BODY), \ #define AT_WRAP(...) __VA_ARGS__
__VA_ARGS__)
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
// Helper macros
// Unused helper macros, kept for BC:
#define AT_AP_VAR(N, T, ...) \ #define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__)) AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to // Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below. // support. To bump this, you must regenerate the macros below.
@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
num_args = 60 num_args = 60
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1): for i in range(1, num_args+1):
args = ', '.join(f'_{i}' for i in range(1, i+1)) args = ', '.join(f'_{i}' for i in range(1, i+1))
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)]) cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@ -112,6 +135,8 @@ for i in range(1, num_args+1):
// Begin generated code // Begin generated code
// clang-format off // clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N) #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)

View File

@ -252,13 +252,13 @@ MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd,
if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) { if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) {
if (flags_ & ALLOCATOR_MAPPED_SHARED) { if (flags_ & ALLOCATOR_MAPPED_SHARED) {
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition) // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
if ((fd = open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) { if ((fd = open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")"); TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
} }
} else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) { } else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) {
#ifdef HAVE_SHM_OPEN #ifdef HAVE_SHM_OPEN
// NOLINTNEXTLINE(bugprone-assignment-in-if-condition) // NOLINTNEXTLINE(bugprone-assignment-in-if-condition)
if((fd = shm_open(filename_.c_str(), flags, static_cast<mode_t>(0600))) == -1) { if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) {
TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")"); TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")");
} }
#else #else
@ -503,7 +503,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *fi
void RefcountedMapAllocator::initializeAlloc() { void RefcountedMapAllocator::initializeAlloc() {
TORCH_CHECK(base_ptr_, "base_ptr_ is null"); TORCH_CHECK(base_ptr_, "base_ptr_ is null");
MapInfo *map_info = static_cast<MapInfo*>(base_ptr_); MapInfo *map_info = (MapInfo*)base_ptr_;
#ifdef _WIN32 #ifdef _WIN32
ReleaseContext* r_ctx = new ReleaseContext; ReleaseContext* r_ctx = new ReleaseContext;
@ -539,7 +539,7 @@ void RefcountedMapAllocator::close() {
} }
#else /* _WIN32 */ #else /* _WIN32 */
MapInfo *info = static_cast<MapInfo*>(data); MapInfo *info = (MapInfo*)(data);
if (--info->refcount == 0) { if (--info->refcount == 0) {
#ifdef HAVE_SHM_UNLINK #ifdef HAVE_SHM_UNLINK
if (shm_unlink(filename_.c_str()) == -1) { if (shm_unlink(filename_.c_str()) == -1) {

View File

@ -862,7 +862,7 @@ void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
shape_[dim] = size; shape_[dim] = size;
view_offsets_[dim] += start; view_offsets_[dim] += start;
for (auto& op : operands_) { for (auto& op : operands_) {
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[dim] * start; op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
} }
if (size == 1 && !is_reduction_) { if (size == 1 && !is_reduction_) {
coalesce_dimensions(); coalesce_dimensions();
@ -873,7 +873,7 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic
TORCH_INTERNAL_ASSERT(start_dim <= ndim()); TORCH_INTERNAL_ASSERT(start_dim <= ndim());
for (const auto i : c10::irange(start_dim, ndim())) { for (const auto i : c10::irange(start_dim, ndim())) {
for (auto& op : operands_) { for (auto& op : operands_) {
op.data = (static_cast<char*>(op.data)) + op.stride_bytes[i] * indices[i - start_dim]; op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
} }
shape_[i] = 1; shape_[i] = 1;
} }

View File

@ -41,7 +41,7 @@ inline void serial_for_each(
IntArrayRef strides, IntArrayRef strides,
char** base_ptrs, char** base_ptrs,
size_t ntensors, size_t ntensors,
TensorIteratorBase::loop2d_t loop, typename TensorIteratorBase::loop2d_t loop,
Range range) { Range range) {
const auto ndim = shape.size(); const auto ndim = shape.size();
TORCH_INTERNAL_ASSERT_DEBUG_ONLY( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(

View File

@ -72,16 +72,10 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>); m.impl("random_", unsupportedRandomOp_<Tensor&, std::optional<Generator>>);
m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>); m.impl("rand_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("rand_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>); m.impl("randn_like", unsupportedRandomOp<const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randn_like.generator", unsupportedRandomOp<const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>); m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>); m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.generator", unsupportedRandomOp<const Tensor&, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.Tensor_generator", unsupportedRandomOp<const Tensor&, const Tensor&, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("randint_like.low_generator_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, std::optional<Generator>, TENSOROPTIONS, std::optional<MemoryFormat>>);
m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>); m.impl("rand", unsupportedRandomOp<IntArrayRef, TENSOROPTIONS>);
m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>); m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, std::optional<Generator>, TENSOROPTIONS>);

View File

@ -190,14 +190,12 @@ class IListRef;
* it to a function (e.g. `ImplT::<dispatch-function>(this_)`). * it to a function (e.g. `ImplT::<dispatch-function>(this_)`).
*/ */
#define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
switch (TAG) { \ switch (TAG) { \
TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \ TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \
break; \ break; \
default: \ default: \
TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \
} \ }
C10_DIAGNOSTIC_POP()
enum class IListRefTag { enum class IListRefTag {
#define DEFINE_TAG(tag, ...) tag, #define DEFINE_TAG(tag, ...) tag,

View File

@ -56,7 +56,7 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) {
* in this overloaded version * in this overloaded version
*/ */
template <typename T, typename V> template <typename T, typename V>
C10_HOST_DEVICE inline std::enable_if_t<!std::is_floating_point_v<T>, T>uniform_int(V val) { C10_HOST_DEVICE inline std::enable_if_t<!(std::is_floating_point_v<T>), T>uniform_int(V val) {
if constexpr (std::is_same_v<T, bool>) { if constexpr (std::is_same_v<T, bool>) {
return static_cast<bool>(val & 1); return static_cast<bool>(val & 1);
} else if constexpr (std::is_same_v<T, int64_t>) { } else if constexpr (std::is_same_v<T, int64_t>) {

View File

@ -114,25 +114,25 @@ inline typename remove_symint<T>::type unpackSymInt(T x) {
} }
template <> template <>
inline remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) { inline typename remove_symint<c10::SymInt>::type unpackSymInt(c10::SymInt x) {
return x.guard_int(__FILE__, __LINE__); return x.guard_int(__FILE__, __LINE__);
} }
template <> template <>
inline remove_symint<c10::SymIntArrayRef>::type unpackSymInt( inline typename remove_symint<c10::SymIntArrayRef>::type unpackSymInt(
c10::SymIntArrayRef x) { c10::SymIntArrayRef x) {
return C10_AS_INTARRAYREF_SLOW(x); return C10_AS_INTARRAYREF_SLOW(x);
} }
template <> template <>
inline remove_symint<std::optional<c10::SymInt>>::type unpackSymInt( inline typename remove_symint<std::optional<c10::SymInt>>::type unpackSymInt(
std::optional<c10::SymInt> x) { std::optional<c10::SymInt> x) {
return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__))
: std::nullopt; : std::nullopt;
} }
template <> template <>
inline remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt( inline typename remove_symint<at::OptionalSymIntArrayRef>::type unpackSymInt(
at::OptionalSymIntArrayRef x) { at::OptionalSymIntArrayRef x) {
return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x))
: std::nullopt; : std::nullopt;

View File

@ -631,8 +631,8 @@ call_functor_with_args_from_stack_(
Stack* stack, Stack* stack,
std::index_sequence<ivalue_arg_indices...> /*unused*/, std::index_sequence<ivalue_arg_indices...> /*unused*/,
guts::typelist::typelist<ArgTypes...>* /*unused*/) { guts::typelist::typelist<ArgTypes...>* /*unused*/) {
(void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would
// be unused and we have to silence the compiler warning. // be unused and we have to silence the compiler warning.
// We're explicitly filtering out DispatchKeySet from the argument list. // We're explicitly filtering out DispatchKeySet from the argument list.
// Some kernels take a DispatchKeySet as their first argument in order to // Some kernels take a DispatchKeySet as their first argument in order to

View File

@ -18,7 +18,6 @@ struct TORCH_API EnumType : public NamedType {
TypePtr value, TypePtr value,
std::vector<EnumNameValue> enum_names_values, std::vector<EnumNameValue> enum_names_values,
std::weak_ptr<::torch::jit::CompilationUnit> cu) { std::weak_ptr<::torch::jit::CompilationUnit> cu) {
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
switch (value->kind()) { switch (value->kind()) {
case TypeKind::IntType: case TypeKind::IntType:
case TypeKind::FloatType: case TypeKind::FloatType:
@ -35,7 +34,6 @@ struct TORCH_API EnumType : public NamedType {
value->str(), value->str(),
"', only int, float and string are supported"); "', only int, float and string are supported");
} }
C10_DIAGNOSTIC_POP()
} }
std::string str() const override { std::string str() const override {

View File

@ -601,8 +601,8 @@ std::ostream& IValue::repr(
double d = v.toDouble(); double d = v.toDouble();
int c = std::fpclassify(d); int c = std::fpclassify(d);
if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) { if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
int64_t i = static_cast<int64_t>(d); int64_t i = int64_t(d);
if (static_cast<double>(i) == d) { if (double(i) == d) {
// -0.0 (signed zero) needs to be parsed as -0. // -0.0 (signed zero) needs to be parsed as -0.
if (i == 0 && std::signbit(d)) { if (i == 0 && std::signbit(d)) {
return out << "-" << i << "."; return out << "-" << i << ".";
@ -799,8 +799,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
double d = v.toDouble(); double d = v.toDouble();
int c = std::fpclassify(d); int c = std::fpclassify(d);
if (c == FP_NORMAL || c == FP_ZERO) { if (c == FP_NORMAL || c == FP_ZERO) {
int64_t i = static_cast<int64_t>(d); int64_t i = int64_t(d);
if (static_cast<double>(i) == d) { if (double(i) == d) {
return out << i << "."; return out << i << ".";
} }
} }

View File

@ -41,7 +41,7 @@ void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten);
inline bool is_contiguous_strides( inline bool is_contiguous_strides(
const IntArrayRef sizes, const IntArrayRef sizes,
const IntArrayRef strides) { const IntArrayRef strides) {
size_t n_dim = sizes.size(); int n_dim = static_cast<int>(sizes.size());
if (n_dim == 0) { if (n_dim == 0) {
return true; return true;
} }
@ -50,7 +50,7 @@ inline bool is_contiguous_strides(
return false; return false;
} }
for (int i = static_cast<int>(n_dim) - 2; i >= 0; i--) { for (int i = n_dim - 2; i >= 0; i--) {
if (strides[i] != strides[i + 1] * sizes[i + 1]) { if (strides[i] != strides[i + 1] * sizes[i + 1]) {
return false; return false;
} }
@ -922,7 +922,6 @@ struct TORCH_API DictType : public SharedType {
if (auto dyn = key->castRaw<DynamicType>()) { if (auto dyn = key->castRaw<DynamicType>()) {
kind = dyn->dynamicKind(); kind = dyn->dynamicKind();
} }
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum")
switch (kind) { switch (kind) {
case TypeKind::AnyType: case TypeKind::AnyType:
case TypeKind::IntType: case TypeKind::IntType:
@ -939,7 +938,6 @@ struct TORCH_API DictType : public SharedType {
key->str(), key->str(),
"', only int, float, complex, Tensor, device and string keys are supported"); "', only int, float, complex, Tensor, device and string keys are supported");
} }
C10_DIAGNOSTIC_POP()
} }
// aligned with the format in FunctionSchema // aligned with the format in FunctionSchema
@ -2373,7 +2371,7 @@ private:
}; };
template<> template<>
inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() { inline typename detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this()); return std::static_pointer_cast<NamedType>(static_cast<NamedType *>(this)->shared_from_this());
@ -2382,7 +2380,7 @@ inline detail::CastReturnType<NamedType>::type Type::cast<NamedType>() {
} }
template<> template<>
inline detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const { inline typename detail::CastConstReturnType<NamedType>::type Type::cast<NamedType>() const {
if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType ||
kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) {
return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this()); return std::static_pointer_cast<const NamedType>(static_cast<const NamedType *>(this)->shared_from_this());

View File

@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
auto vals = svreinterpret_u16_bf16(values); auto vals = svreinterpret_u16_bf16(values);
vals = sveor_u16_x(ptrue, vals, mask); vals = sveor_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals); return svreinterpret_bf16_u16(vals);
} };
Vectorized<BFloat16> round() const; Vectorized<BFloat16> round() const;
Vectorized<BFloat16> tan() const; Vectorized<BFloat16> tan() const;
Vectorized<BFloat16> tanh() const; Vectorized<BFloat16> tanh() const;
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
return convert_float_bfloat16(v1, v2); \ return convert_float_bfloat16(v1, v2); \
} }
DEFINE_BF16_FUNC_VIA_FLOAT(isnan) DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle) DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos) DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh) DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin) DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan) DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh) DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf) DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc) DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp) DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2) DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1) DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0) DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e) DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma) DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log) DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2) DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10) DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p) DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin) DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh) DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos) DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh) DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil) DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor) DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round) DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan) DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh) DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc) DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma) DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt) DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal) DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt) DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow) DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==( Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
const Vectorized<BFloat16>& other) const { const Vectorized<BFloat16>& other) const {

View File

@ -19,13 +19,6 @@ inline namespace CPU_CAPABILITY {
#error "Big endian is not supported." #error "Big endian is not supported."
#endif #endif
// GCC does not properly optimize bf16 operators
#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19)
#define BF16_ARITHMETIC_SUPPORTED() 1
#else
#define BF16_ARITHMETIC_SUPPORTED() 0
#endif
// Unlike the float16_t family of types, bfloat16_t is not available // Unlike the float16_t family of types, bfloat16_t is not available
// when we're not targeting bfloat16 hardware support on some // when we're not targeting bfloat16 hardware support on some
// platforms (but not Mac, so we have to be careful not to shadow the // platforms (but not Mac, so we have to be careful not to shadow the
@ -359,72 +352,18 @@ class Vectorized<c10::BFloat16> : public Vectorized16<
other, &Vectorized<float>::name); \ other, &Vectorized<float>::name); \
} }
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
Vectorized frac() const; Vectorized frac() const;
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt)
#ifdef __ARM_FEATURE_BF16
// Flip sign bit
Vectorized<c10::BFloat16> neg() const {
return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768));
}
// Fast reciprocal is fine because we are truncating results
Vectorized<c10::BFloat16> reciprocal() const {
auto x = vcvtq_low_f32_bf16(values);
auto y = vcvtq_high_f32_bf16(values);
x = vrecpeq_f32(x);
y = vrecpeq_f32(y);
return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y);
}
// Clearing the sign bit
Vectorized<c10::BFloat16> abs() const {
return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF);
}
#else
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg)
DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal)
#endif
// These functions are optimized on clang-21+
#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21)
Vectorized<c10::BFloat16> operator==(
const Vectorized<c10::BFloat16>& other) const {
return values == other.values;
}
Vectorized<c10::BFloat16> operator!=(
const Vectorized<c10::BFloat16>& other) const {
return values != other.values;
}
Vectorized<c10::BFloat16> operator<(
const Vectorized<c10::BFloat16>& other) const {
return values < other.values;
}
Vectorized<c10::BFloat16> operator<=(
const Vectorized<c10::BFloat16>& other) const {
return values <= other.values;
}
Vectorized<c10::BFloat16> operator>(
const Vectorized<c10::BFloat16>& other) const {
return values > other.values;
}
Vectorized<c10::BFloat16> operator>=(
const Vectorized<c10::BFloat16>& other) const {
return values >= other.values;
}
#else
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<=)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>)
DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator>=)
#endif
#undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD #undef DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
#undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD #undef DEFINE_BINARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD
@ -473,52 +412,28 @@ template <>
Vectorized<c10::BFloat16> inline operator+( Vectorized<c10::BFloat16> inline operator+(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) { const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x + y;
#else
return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b); return binary_operator_via_float(std::plus<Vectorized<float>>(), a, b);
#endif
} }
template <> template <>
Vectorized<c10::BFloat16> inline operator-( Vectorized<c10::BFloat16> inline operator-(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) { const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x - y;
#else
return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b); return binary_operator_via_float(std::minus<Vectorized<float>>(), a, b);
#endif
} }
template <> template <>
Vectorized<c10::BFloat16> inline operator*( Vectorized<c10::BFloat16> inline operator*(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) { const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x * y;
#else
return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b); return binary_operator_via_float(std::multiplies<Vectorized<float>>(), a, b);
#endif
} }
template <> template <>
Vectorized<c10::BFloat16> inline operator/( Vectorized<c10::BFloat16> inline operator/(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b) { const Vectorized<c10::BFloat16>& b) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
return x / y;
#else
return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b); return binary_operator_via_float(std::divides<Vectorized<float>>(), a, b);
#endif
} }
// frac. Implement this here so we can use subtraction // frac. Implement this here so we can use subtraction
@ -629,19 +544,12 @@ Vectorized<c10::BFloat16> inline fmadd(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b, const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) { const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y + z;
#else
// NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also, // NOTE [BF16 FMA]: There isn't an FMA that accumulates into BF16! Also,
// vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered // vbfmlalbq_f32 and vbfmlaltq_f32 take the even and odd-numbered
// elements, not the bottom and top half, so they don't seem // elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in // particularly useful here. Ideally we would include dot product in
// the Vectorized interface... // the Vectorized interface...
return a * b + c; return a * b + c;
#endif
} }
template <> template <>
@ -649,15 +557,8 @@ Vectorized<c10::BFloat16> inline fnmadd(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b, const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) { const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y + z;
#else
// See NOTE [BF16 FMA] above. // See NOTE [BF16 FMA] above.
return -a * b + c; return -a * b + c;
#endif
} }
template <> template <>
@ -665,15 +566,8 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b, const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) { const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return x * y - z;
#else
// See NOTE [BF16 FMA] above. // See NOTE [BF16 FMA] above.
return a * b - c; return a * b - c;
#endif
} }
template <> template <>
@ -681,15 +575,8 @@ Vectorized<c10::BFloat16> inline fnmsub(
const Vectorized<c10::BFloat16>& a, const Vectorized<c10::BFloat16>& a,
const Vectorized<c10::BFloat16>& b, const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) { const Vectorized<c10::BFloat16>& c) {
#if BF16_ARITHMETIC_SUPPORTED()
bfloat16x8_t x = a;
bfloat16x8_t y = b;
bfloat16x8_t z = c;
return (-x) * y - z;
#else
// See NOTE [BF16 FMA] above. // See NOTE [BF16 FMA] above.
return -a * b - c; return -a * b - c;
#endif
} }
#endif // !defined(C10_MOBILE) && defined(__aarch64__) #endif // !defined(C10_MOBILE) && defined(__aarch64__)

View File

@ -6,9 +6,9 @@ namespace at::vec {
inline namespace CPU_CAPABILITY { inline namespace CPU_CAPABILITY {
#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256))
// Enable auto-vectorization for clang-17+ // Enable auto-vectorization for GCC-13+ and clang-17+
// GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 // GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001
#if defined(__clang__) && (__clang_major__ >= 17) #if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17))
template <typename from_type, typename to_type> template <typename from_type, typename to_type>
inline void convertImpl( inline void convertImpl(
@ -191,37 +191,22 @@ inline void convert(const at::Half* src, bool* dst, int64_t n) {
} }
#endif #endif
#ifdef __ARM_FEATURE_BF16
template <typename to_type> CONVERT_TEMPLATE(bfloat16_t, uint8_t)
inline void convertFromBf16Impl( CONVERT_TEMPLATE(bfloat16_t, int8_t)
const c10::BFloat16* __restrict src, CONVERT_TEMPLATE(bfloat16_t, int16_t)
to_type* __restrict dst, CONVERT_TEMPLATE(bfloat16_t, int32_t)
int64_t n) { CONVERT_TEMPLATE(bfloat16_t, int64_t)
const uint16_t* srcPtr = reinterpret_cast<const uint16_t*>(src); CONVERT_TEMPLATE(bfloat16_t, bfloat16_t)
uint64_t len = static_cast<uint64_t>(n); CONVERT_TEMPLATE(bfloat16_t, float)
for (uint64_t i = 0; i < len; i++) { CONVERT_TEMPLATE(bfloat16_t, double)
uint32_t tmp = static_cast<uint32_t>(srcPtr[i]) << 16; CONVERT_TEMPLATE(uint8_t, bfloat16_t)
float tmpF; CONVERT_TEMPLATE(int8_t, bfloat16_t)
__builtin_memcpy(&tmpF, &tmp, sizeof(float)); CONVERT_TEMPLATE(int16_t, bfloat16_t)
dst[i] = static_cast<to_type>(tmpF); CONVERT_TEMPLATE(int32_t, bfloat16_t)
} CONVERT_TEMPLATE(int64_t, bfloat16_t)
} CONVERT_TEMPLATE(float, bfloat16_t)
#define CONVERT_FROM_BF16_TEMPLATE(to_type) \ CONVERT_TEMPLATE(double, bfloat16_t)
template <> \
inline void convert(const c10::BFloat16* src, to_type* dst, int64_t n) { \
return convertFromBf16Impl<to_type>(src, dst, n); \
}
CONVERT_FROM_BF16_TEMPLATE(uint8_t)
CONVERT_FROM_BF16_TEMPLATE(int8_t)
CONVERT_FROM_BF16_TEMPLATE(int16_t)
CONVERT_FROM_BF16_TEMPLATE(int32_t)
CONVERT_FROM_BF16_TEMPLATE(int64_t)
CONVERT_FROM_BF16_TEMPLATE(float)
CONVERT_FROM_BF16_TEMPLATE(double)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
CONVERT_FROM_BF16_TEMPLATE(float16_t)
#endif
inline void convertBoolToBfloat16Impl( inline void convertBoolToBfloat16Impl(
const bool* __restrict src, const bool* __restrict src,
@ -262,6 +247,8 @@ inline void convert(const c10::BFloat16* src, bool* dst, int64_t n) {
#endif #endif
#endif
template <typename src_t> template <typename src_t>
struct VecConvert< struct VecConvert<
float, float,

View File

@ -309,7 +309,7 @@ class Vectorized<float> {
DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1)
// Implementation copied from Arm Optimized Routine // Implementation copied from Arm Optimized Routine
// https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c
inline Vectorized<float> vexpq_f32_u20() const { Vectorized<float> exp_u20() const {
// bail out to sleef if it's a special case: // bail out to sleef if it's a special case:
// i.e. there's an input s.t. |input| > 87.3.... // i.e. there's an input s.t. |input| > 87.3....
const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f);
@ -348,9 +348,6 @@ class Vectorized<float> {
return vfmaq_f32(scale, poly, scale); return vfmaq_f32(scale, poly, scale);
} }
Vectorized<float> exp_u20() const {
return vexpq_f32_u20();
}
Vectorized<float> fexp_u20() const { Vectorized<float> fexp_u20() const {
return exp_u20(); return exp_u20();
} }
@ -637,7 +634,7 @@ inline Vectorized<float> Vectorized<float>::erf() const {
// - exp(- x * x) // - exp(- x * x)
auto pow_2 = (*this) * (*this); auto pow_2 = (*this) * (*this);
auto neg_pow_2 = pow_2 ^ neg_zero_vec; auto neg_pow_2 = pow_2 ^ neg_zero_vec;
auto tmp4 = neg_pow_2.vexpq_f32_u20(); auto tmp4 = neg_pow_2.exp();
auto tmp5 = tmp4 ^ neg_zero_vec; auto tmp5 = tmp4 ^ neg_zero_vec;
// erf(x) = sign(x) * (1 - r * t * exp(- x * x)) // erf(x) = sign(x) * (1 - r * t * exp(- x * x))
auto tmp6 = t * tmp5; auto tmp6 = t * tmp5;

View File

@ -514,7 +514,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>; using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>; using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
using value_type = c10::qint8::underlying; using value_type = typename c10::qint8::underlying;
public: public:
using Vectorizedqi::Vectorizedqi; using Vectorizedqi::Vectorizedqi;
@ -727,7 +727,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>; using float_vec_return_type = std::array<Vectorized<float>, kFloatNumVecs>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>; using int_vec_return_type = std::array<Vectorized<c10::qint32>, kIntNumVecs>;
using value_type = c10::quint8::underlying; using value_type = typename c10::quint8::underlying;
public: public:
using Vectorizedqi::Vectorizedqi; using Vectorizedqi::Vectorizedqi;

View File

@ -567,7 +567,7 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
using float_vec_return_type = std::array<Vectorized<float>, 4>; using float_vec_return_type = std::array<Vectorized<float>, 4>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>; using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
using value_type = c10::qint8::underlying; using value_type = typename c10::qint8::underlying;
public: public:
using Vectorizedqi::Vectorizedqi; using Vectorizedqi::Vectorizedqi;
@ -804,7 +804,7 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
using float_vec_return_type = std::array<Vectorized<float>, 4>; using float_vec_return_type = std::array<Vectorized<float>, 4>;
using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>; using int_vec_return_type = std::array<Vectorized<c10::qint32>, 4>;
using value_type = c10::quint8::underlying; using value_type = typename c10::quint8::underlying;
public: public:
using Vectorizedqi::Vectorizedqi; using Vectorizedqi::Vectorizedqi;

View File

@ -672,7 +672,7 @@ struct Vectorized {
return map(std::sqrt); return map(std::sqrt);
} }
Vectorized<T> reciprocal() const { Vectorized<T> reciprocal() const {
return map([](T x) { return (T)1 / x; }); return map([](T x) { return (T)(1) / x; });
} }
Vectorized<T> rsqrt() const { Vectorized<T> rsqrt() const {
return map([](T x) { return (T)1 / std::sqrt(x); }); return map([](T x) { return (T)1 / std::sqrt(x); });

View File

@ -46,7 +46,7 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
map( map(
[](const Vectorized<scalar_t>& x) { [](const Vectorized<scalar_t>& x) {
return Vectorized<scalar_t>((scalar_t)1) / x.sqrt(); return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
}, },
out + begin, out + begin,
in + begin, in + begin,

View File

@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM #ifndef USE_ROCM
at::Half halpha; at::Half halpha;
at::Half hbeta; at::Half hbeta;
uint32_t mask = -1;
#endif #endif
void * alpha_ptr = &alpha; void * alpha_ptr = &alpha;
void * beta_ptr = &beta; void * beta_ptr = &beta;
@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS(); auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction != if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
mask = uint32_t mask =
fp16_reduction == fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS(); auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction != if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
mask = uint32_t mask =
bf16_reduction == bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {}; cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0; int returnedResult = 0;
// on Blackwell+, we fake a n > 1 matmul when querying heuristics TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance ltHandle,
#ifndef USE_ROCM computeDesc.descriptor(),
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10; Adesc.descriptor(),
#else Bdesc.descriptor(),
const bool lie_to_cublaslt = false; Cdesc.descriptor(),
#endif Cdesc.descriptor(),
if (lie_to_cublaslt) { preference.descriptor(),
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T); 1,
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc); &heuristicResult,
&returnedResult));
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
if (returnedResult == 0) { if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
} }
@ -1597,7 +1572,7 @@ bool gemm_and_bias(
} }
using opmath_t = at::opmath_type<Dtype>; using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr opmath_t beta_val = 0; // bias is added in epilogue
cudaDataType_t abType = CUDA_R_32F; cudaDataType_t abType = CUDA_R_32F;
cudaDataType_t cType = CUDA_R_32F; cudaDataType_t cType = CUDA_R_32F;
@ -1686,22 +1661,15 @@ bool gemm_and_bias(
_syncCurrentWithCarveoutStream(stream, true); _syncCurrentWithCarveoutStream(stream, true);
} }
#endif #endif
const auto epilogue = [&]() -> cublasLtEpilogue_t { cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
// The cuBLAS documentation indicates that if (activation == GEMMAndBiasActivationEpilogue::RELU) {
// *_<ACTIVATION>_BIAS = *_<ACTIVATION>, epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
// but we keep it verbose here for clarity. } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
switch (activation) { epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
case GEMMAndBiasActivationEpilogue::RELU: }
return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
case GEMMAndBiasActivationEpilogue::GELU:
return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
default:
return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
}
}();
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
if (bias) { if (bias != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias); computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
} }

View File

@ -194,8 +194,8 @@ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) {
void CUDAGeneratorState::capture_prologue() { void CUDAGeneratorState::capture_prologue() {
capturing_ = true; capturing_ = true;
offset_intragraph_ = 0; offset_intragraph_ = 0;
seed_extragraph_.fill_(static_cast<int64_t>(seed_)); seed_extragraph_.fill_(int64_t(seed_));
offset_extragraph_.fill_(0); offset_extragraph_.fill_(int64_t(0));
} }
/** /**
@ -216,8 +216,8 @@ void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) {
at::cuda::assertNotCapturing( at::cuda::assertNotCapturing(
"Cannot prepare for replay during capturing stage."); "Cannot prepare for replay during capturing stage.");
if (wholegraph_increment) { if (wholegraph_increment) {
seed_extragraph_.fill_(static_cast<int64_t>(seed_)); seed_extragraph_.fill_(int64_t(seed_));
offset_extragraph_.fill_(static_cast<int64_t>(philox_offset_per_thread_)); offset_extragraph_.fill_(int64_t(philox_offset_per_thread_));
// Applies the total increment achieved during previous captures to update the // Applies the total increment achieved during previous captures to update the
// offset. // offset.
increase(wholegraph_increment); increase(wholegraph_increment);
@ -329,7 +329,7 @@ c10::intrusive_ptr<c10::TensorImpl> CUDAGeneratorImpl::get_state() const {
constexpr size_t offset_size = sizeof(int64_t); constexpr size_t offset_size = sizeof(int64_t);
constexpr size_t total_size = seed_size + offset_size; constexpr size_t total_size = seed_size + offset_size;
auto state_tensor = at::detail::empty_cpu({static_cast<int64_t>(total_size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt);
auto rng_state = state_tensor.data_ptr<uint8_t>(); auto rng_state = state_tensor.data_ptr<uint8_t>();
auto current_seed = this->current_seed(); auto current_seed = this->current_seed();
auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t> auto offset = static_cast<int64_t>(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic<int64_t>

View File

@ -1,90 +1,78 @@
#include <ATen/cuda/CUDAGreenContext.h> #include <ATen/cuda/CUDAGreenContext.h>
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
#include <c10/cuda/driver_api.h>
#include <stdexcept>
#include <vector>
#define HAS_CUDA_GREEN_CONTEXT() 1
#else
#define HAS_CUDA_GREEN_CONTEXT() 0
// Suppress unsued private field warnings as this class is not supposed to be called
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field")
#endif
namespace at::cuda { namespace at::cuda {
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#if CUDA_HAS_GREEN_CONTEXT
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { cudaFree(0);
#if HAS_CUDA_GREEN_CONTEXT() }
int driver_version;
C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version));
TORCH_CHECK(
driver_version >= 12080, "cuda driver too old to use green context!");
CUcontext pctx = nullptr;
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx));
if (C10_UNLIKELY(!pctx)) {
TORCH_WARN(
"Attempted to create a green context but"
" there was no primary context! Creating a primary context...");
cudaFree(0); CUdevice device;
} device_id_ = device_id;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id));
CUdevice device; // Get device resources
device_id_ = device_id; CUdevResource device_resource;
C10_CUDA_DRIVER_CHECK( C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_(
c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); device, &device_resource, CU_DEV_RESOURCE_TYPE_SM));
// Get device resources // Split resources
CUdevResource device_resource; std::vector<CUdevResource> result(1);
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( auto result_data = result.data();
device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); unsigned int nb_groups = 1;
CUdevResource remaining;
// Split resources C10_CUDA_DRIVER_CHECK(
std::vector<CUdevResource> result(1); c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
auto result_data = result.data(); result_data,
unsigned int nb_groups = 1; &nb_groups,
CUdevResource remaining; &device_resource,
&remaining,
0, // default flags
num_sms));
C10_CUDA_DRIVER_CHECK( TORCH_CHECK(nb_groups == 1, "Failed to create single resource group");
c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_(
result_data,
&nb_groups,
&device_resource,
&remaining,
0, // default flags
num_sms));
TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); // Generate resource descriptor
CUdevResourceDesc desc;
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_(
&desc, result_data, 1));
// Generate resource descriptor // Create green context
CUdevResourceDesc desc; // CU_GREEN_CTX_DEFAULT_STREAM is required per docs:
C10_CUDA_DRIVER_CHECK( // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html
c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_(
&desc, result_data, 1)); &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Create green context // Convert to regular context
// CU_GREEN_CTX_DEFAULT_STREAM is required per docs: C10_CUDA_DRIVER_CHECK(
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
&green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM));
// Convert to regular context
C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_));
TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!");
#else #else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif #endif
} }
std::unique_ptr<GreenContext> GreenContext::create( std::unique_ptr<GreenContext> GreenContext::create(
uint32_t num_sms, uint32_t num_sms,
std::optional<uint32_t> device_id) { std::optional<uint32_t> device_id) {
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
if (!device_id.has_value()) { if (!device_id.has_value()) {
device_id = at::cuda::current_device(); device_id = at::cuda::current_device();
} }
return std::unique_ptr<GreenContext>(new GreenContext(device_id.value(), num_sms)); return std::make_unique<GreenContext>(device_id.value(), num_sms);
#else #else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif #endif
@ -92,7 +80,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
// Implement move operations // Implement move operations
GreenContext::GreenContext(GreenContext&& other) noexcept{ GreenContext::GreenContext(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
device_id_ = std::exchange(other.device_id_, -1); device_id_ = std::exchange(other.device_id_, -1);
green_ctx_ = std::exchange(other.green_ctx_, nullptr); green_ctx_ = std::exchange(other.green_ctx_, nullptr);
context_ = std::exchange(other.context_, nullptr); context_ = std::exchange(other.context_, nullptr);
@ -103,7 +91,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
} }
GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
if (this != &other) { if (this != &other) {
// Clean up current resources // Clean up current resources
if (green_ctx_) { if (green_ctx_) {
@ -132,7 +120,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
} }
GreenContext::~GreenContext() noexcept{ GreenContext::~GreenContext() noexcept{
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
C10_CUDA_DRIVER_CHECK( C10_CUDA_DRIVER_CHECK(
c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_));
#else #else
@ -140,9 +128,25 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
#endif #endif
} }
// Get the underlying CUDA context
CUcontext GreenContext::getContext() const {
#if CUDA_HAS_GREEN_CONTEXT
return context_;
#else
TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!");
#endif
}
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx GreenContext::getGreenContext() const {
return green_ctx_;
}
#endif
// Make this context current // Make this context current
void GreenContext::setContext() { void GreenContext::setContext() {
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
auto current_stream = c10::cuda::getCurrentCUDAStream(); auto current_stream = c10::cuda::getCurrentCUDAStream();
parent_stream_ = current_stream.stream(); parent_stream_ = current_stream.stream();
@ -171,7 +175,7 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) {
} }
void GreenContext::popContext() { void GreenContext::popContext() {
#if HAS_CUDA_GREEN_CONTEXT() #if CUDA_HAS_GREEN_CONTEXT
// see above note about stream being hardcoded to the default stream // see above note about stream being hardcoded to the default stream
at::cuda::CUDAEvent ev; at::cuda::CUDAEvent ev;
ev.record(c10::cuda::getCurrentCUDAStream()); ev.record(c10::cuda::getCurrentCUDAStream());

View File

@ -1,38 +1,53 @@
#pragma once #pragma once
#include <ATen/cuda/CUDAEvent.h> #include <ATen/cuda/CUDAEvent.h>
#include <cuda.h>
// Forward declare green context as opaque ptr #if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
typedef struct CUgreenCtx_st* CUgreenCtx; #include <c10/cuda/driver_api.h>
#include <cuda.h>
#include <memory>
#include <stdexcept>
#include <vector>
#define CUDA_HAS_GREEN_CONTEXT 1
#else
#define CUDA_HAS_GREEN_CONTEXT 0
#endif
namespace at::cuda { namespace at::cuda {
class TORCH_CUDA_CPP_API GreenContext { class TORCH_CUDA_CPP_API GreenContext {
public: public:
// Green context creation GreenContext(uint32_t device_id, uint32_t num_sms);
static std::unique_ptr<GreenContext> create(
uint32_t num_sms, static std::unique_ptr<GreenContext> create(uint32_t num_sms, std::optional<uint32_t> device_id);
std::optional<uint32_t> device_id);
~GreenContext() noexcept;
// Delete copy constructor and assignment // Delete copy constructor and assignment
GreenContext(const GreenContext&) = delete; GreenContext(const GreenContext&) = delete;
GreenContext& operator=(const GreenContext&) = delete; GreenContext& operator=(const GreenContext&) = delete;
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
~GreenContext() noexcept;
// Get the underlying CUDA context
CUcontext getContext() const;
// Get the underlying green context
#if CUDA_HAS_GREEN_CONTEXT
CUgreenCtx getGreenContext() const;
#endif
// Make this context current // Make this context current
void setContext(); void setContext();
void popContext(); void popContext();
private: private:
GreenContext(uint32_t device_id, uint32_t num_sms); #if CUDA_HAS_GREEN_CONTEXT
// Implement move operations
GreenContext(GreenContext&& other) noexcept;
GreenContext& operator=(GreenContext&& other) noexcept;
int32_t device_id_ = -1; int32_t device_id_ = -1;
CUgreenCtx green_ctx_ = nullptr; CUgreenCtx green_ctx_ = nullptr;
CUcontext context_ = nullptr; CUcontext context_ = nullptr;
cudaStream_t parent_stream_ = nullptr; cudaStream_t parent_stream_ = nullptr;
#endif
}; };
} // namespace at::cuda } // namespace at::cuda

View File

@ -7,6 +7,17 @@
#endif #endif
#if defined(USE_ROCM)
// hipSparse const API added in v2.4.0
#if HIPSPARSE_VERSION >= 200400
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#else
#define AT_USE_HIPSPARSE_GENERIC_API() 1
#endif
#else // USE_ROCM
#define AT_USE_HIPSPARSE_GENERIC_API() 0
#endif // USE_ROCM
// cuSparse Generic API spsv function was added in CUDA 11.3.0 // cuSparse Generic API spsv function was added in CUDA 11.3.0
#if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500) #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
#define AT_USE_CUSPARSE_GENERIC_SPSV() 1 #define AT_USE_CUSPARSE_GENERIC_SPSV() 1

View File

@ -155,8 +155,8 @@ size_t parseChosenWorkspaceSize() {
while (next != end) { while (next != end) {
std::smatch match = *next; std::smatch match = *next;
TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)"); TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
size_t curr_size = std::stoull(match.str(1)); size_t curr_size = (size_t) std::stoi(match.str(1));
size_t count = std::stoull(match.str(2)); size_t count = (size_t) std::stoi(match.str(2));
total_size += curr_size * 1024 * count; total_size += curr_size * 1024 * count;
next++; next++;
} }

View File

@ -55,14 +55,6 @@ struct numeric_limits<int8_t> {
static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; }
}; };
template <>
struct numeric_limits<uint16_t> {
static inline __host__ __device__ uint16_t lowest() { return 0; }
static inline __host__ __device__ uint16_t max() { return UINT16_MAX; }
static inline __host__ __device__ uint16_t lower_bound() { return 0; }
static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; }
};
template <> template <>
struct numeric_limits<int16_t> { struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
@ -71,14 +63,6 @@ struct numeric_limits<int16_t> {
static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; }
}; };
template <>
struct numeric_limits<uint32_t> {
static inline __host__ __device__ uint32_t lowest() { return 0; }
static inline __host__ __device__ uint32_t max() { return UINT32_MAX; }
static inline __host__ __device__ uint32_t lower_bound() { return 0; }
static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; }
};
template <> template <>
struct numeric_limits<int32_t> { struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
@ -87,21 +71,6 @@ struct numeric_limits<int32_t> {
static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; }
}; };
template <>
struct numeric_limits<uint64_t> {
#ifdef _MSC_VER
static inline __host__ __device__ uint64_t lowest() { return 0; }
static inline __host__ __device__ uint64_t max() { return _UI64_MAX; }
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; }
#else
static inline __host__ __device__ uint64_t lowest() { return 0; }
static inline __host__ __device__ uint64_t max() { return UINT64_MAX; }
static inline __host__ __device__ uint64_t lower_bound() { return 0; }
static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; }
#endif
};
template <> template <>
struct numeric_limits<int64_t> { struct numeric_limits<int64_t> {
#ifdef _MSC_VER #ifdef _MSC_VER

View File

@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy // radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting // the data, so we can save template instantiations by reinterpreting
// it as an opaque type. // it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; }; template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size> template<typename key_t, int value_size>
void radix_sort_pairs_impl( void radix_sort_pairs_impl(

View File

@ -2,6 +2,8 @@
#include <ATen/Tensor.h> #include <ATen/Tensor.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <mutex>
namespace at { namespace at {
namespace cuda { namespace cuda {
namespace detail { namespace detail {
@ -10,36 +12,39 @@ __device__ __constant__ float cublas_one_device;
__device__ __constant__ float cublas_zero_device; __device__ __constant__ float cublas_zero_device;
float *get_cublas_device_one() { float *get_cublas_device_one() {
static float *ptr = nullptr; static c10::once_flag init_flag;
static auto init_flag = [&]() {
c10::call_once(init_flag, []() {
const float one = 1.f; const float one = 1.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float))); AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float)));
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device)); });
return true;
}();
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_one_device));
return ptr; return ptr;
} }
float *get_cublas_device_zero() { float *get_cublas_device_zero() {
static float *ptr = nullptr; static c10::once_flag init_flag;
static auto init_flag = [&]() {
c10::call_once(init_flag, []() {
const float zero = 0.f; const float zero = 0.f;
AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float))); AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float)));
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device)); });
return true;
}();
float *ptr;
AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast<void**>(&ptr), cublas_zero_device));
return ptr; return ptr;
} }
float *get_user_alpha_ptr() { float *get_user_alpha_ptr() {
static float *alpha_ptr; static float *alpha_ptr;
static bool init_flag [[maybe_unused]] = []() { static c10::once_flag init_flag;
c10::call_once(init_flag, []() {
AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float))); AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float)));
return true; });
}();
return alpha_ptr; return alpha_ptr;
} }

View File

@ -21,7 +21,6 @@
#if AT_CUDNN_ENABLED() #if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h> #include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#endif #endif
#if AT_MAGMA_ENABLED() #if AT_MAGMA_ENABLED()
@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
#endif #endif
} }
long CUDAHooks::versionRuntimeCuDNN() const {
#if AT_CUDNN_ENABLED()
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
#else
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionCuDNNFrontend() const {
#if AT_CUDNN_ENABLED()
return CUDNN_FRONTEND_VERSION;
#else
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionMIOpen() const { long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED() #if AT_ROCM_ENABLED()
return MIOPEN_VERSION_MAJOR * 10000 + return MIOPEN_VERSION_MAJOR * 10000 +

View File

@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCUDART() const override; bool hasCUDART() const override;
long versionCUDART() const override; long versionCUDART() const override;
long versionCuDNN() const override; long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
long versionMIOpen() const override; long versionMIOpen() const override;
std::string showConfig() const override; std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override; double batchnormMinEpsilonCuDNN() const override;

View File

@ -3,7 +3,6 @@
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <c10/util/irange.h> #include <c10/util/irange.h>
#include <array>
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
@ -137,9 +136,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
"Weight strides: ", t.strides(), "\n", "Weight strides: ", t.strides(), "\n",
"cuDNN suggested memory_format: ", memory_format); "cuDNN suggested memory_format: ", memory_format);
std::array<int, CUDNN_DIM_MAX> size; int size[CUDNN_DIM_MAX];
for (const auto i : c10::irange(dim)) { for (const auto i : c10::irange(dim)) {
size[i] = static_cast<int>(t.size(i)); size[i] = (int) t.size(i);
} }
for (const auto i : c10::irange(dim, pad)) { for (const auto i : c10::irange(dim, pad)) {
size[i] = 1; size[i] = 1;
@ -157,7 +156,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo
default: default:
TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters");
} }
set(getDataType(t), static_cast<int>(dim), size.data(), filter_format); set(getDataType(t), static_cast<int>(dim), size, filter_format);
} }
std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) { std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) {

Some files were not shown because too many files have changed in this diff Show More