Compare commits

..

1 Commits

Author SHA1 Message Date
8cef91fb74 prints for static inputs indices 2025-11-04 13:19:14 -08:00
1068 changed files with 21059 additions and 36266 deletions

View File

@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
ARG DEVTOOLSET_VERSION=13
ARG DEVTOOLSET_VERSION=11
RUN yum -y update
RUN yum -y install epel-release
# install glibc-langpack-en make sure en_US.UTF-8 locale is available
RUN yum -y install glibc-langpack-en
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
# Just add everything as a safe.directory for git since these will be used in multiple places with git
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
# Install CUDA
FROM base as cuda
ARG CUDA_VERSION=12.6
ARG DEVTOOLSET_VERSION=13
RUN rm -rf /usr/local/cuda-*
ADD ./common/install_cuda.sh install_cuda.sh
COPY ./common/install_nccl.sh install_nccl.sh
@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
# Preserve CUDA_VERSION for the builds
ENV CUDA_VERSION=${CUDA_VERSION}
# Make things in our path by default
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH
FROM cuda as cuda12.6
RUN bash ./install_cuda.sh 12.6
@ -70,22 +68,8 @@ FROM cuda as cuda13.0
RUN bash ./install_cuda.sh 13.0
ENV DESIRED_CUDA=13.0
FROM ${ROCM_IMAGE} as rocm_base
ARG DEVTOOLSET_VERSION=13
ENV LC_ALL en_US.UTF-8
ENV LANG en_US.UTF-8
ENV LANGUAGE en_US.UTF-8
# Install devtoolset on ROCm base image
RUN yum -y update && \
yum -y install epel-release && \
yum -y install glibc-langpack-en && \
yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
RUN git config --global --add safe.directory '*'
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
FROM rocm_base as rocm
FROM ${ROCM_IMAGE} as rocm
ARG PYTORCH_ROCM_ARCH
ARG DEVTOOLSET_VERSION=13
ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
ADD ./common/install_mkl.sh install_mkl.sh
RUN bash ./install_mkl.sh && rm install_mkl.sh
@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0
# Final step
FROM ${BASE_TARGET} as final
ARG DEVTOOLSET_VERSION=13
COPY --from=openssl /opt/openssl /opt/openssl
COPY --from=patchelf /patchelf /usr/local/bin/patchelf
COPY --from=conda /opt/conda /opt/conda

View File

@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
BASE_TARGET=rocm
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
;;
*)
@ -59,7 +63,7 @@ docker build \
--target final \
--progress plain \
--build-arg "BASE_TARGET=${BASE_TARGET}" \
--build-arg "DEVTOOLSET_VERSION=13" \
--build-arg "DEVTOOLSET_VERSION=11" \
${EXTRA_BUILD_ARGS} \
-t ${tmp_tag} \
$@ \

View File

@ -168,18 +168,6 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.11-clang12)
ANACONDA_PYTHON_VERSION=3.11
CLANG_VERSION=12
VISION=no
TRITON=no
;;
pytorch-linux-jammy-py3.12-clang12)
ANACONDA_PYTHON_VERSION=3.12
CLANG_VERSION=12
VISION=no
TRITON=no
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10
@ -207,9 +195,9 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
GCC_VERSION=11
VISION=yes
XPU_VERSION=2025.2
NINJA_VERSION=1.9.0
@ -260,12 +248,6 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes
;;
pytorch-linux-jammy-py3.12-triton-cpu)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
@ -279,9 +261,9 @@ case "$tag" in
PYTHON_VERSION=3.10
CUDA_VERSION=12.8.1
;;
pytorch-linux-jammy-aarch64-py3.10-gcc13)
pytorch-linux-jammy-aarch64-py3.10-gcc11)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
GCC_VERSION=11
ACL=yes
VISION=yes
OPENBLAS=yes
@ -289,19 +271,9 @@ case "$tag" in
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-clang21)
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
CLANG_VERSION=21
ACL=yes
VISION=yes
OPENBLAS=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
GCC_VERSION=11
ACL=yes
VISION=yes
OPENBLAS=yes
@ -387,7 +359,6 @@ docker build \
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "HALIDE=${HALIDE}" \
--build-arg "PALLAS=${PALLAS}" \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \

View File

@ -1 +0,0 @@
0.8.0

View File

@ -1 +1 @@
bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7
7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd

View File

@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
# work around ubuntu apt-get conflicts
sudo apt-get -y -f install
wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
if [[ $CLANG_VERSION -ge 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
if [[ $CLANG_VERSION == 18 ]]; then
apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
fi
fi

View File

@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
# Need the official toolchain repo to get alternate packages
add-apt-repository ppa:ubuntu-toolchain-r/test
apt-get update
apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
apt-get install -y g++-$GCC_VERSION
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50
# Cleanup package manager
apt-get autoclean && apt-get clean

View File

@ -1,40 +0,0 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
function install_jax_12() {
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}
function install_jax_13() {
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}
# idiomatic parameter and option handling in sh
while test $# -gt 0
do
case "$1" in
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
;;
13.0|13.0.*) install_jax_13;
;;
*) echo "bad argument $1"; exit 1
;;
esac
shift
done

View File

@ -1,56 +0,0 @@
#!/bin/bash
# Script used only in CD pipeline
set -ex
# install dependencies
dnf -y install gmp-devel libmpc-devel texinfo flex bison
cd /usr/local/src
# fetch source for gcc 13
git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
mkdir -p gcc-13.3.0/build-gomp
cd gcc-13.3.0/build-gomp
# configure gcc build
# I got these flags by:
# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
# dnf install -y dnf-plugins-core rpmdevtools
# dnf download --source libgomp
# 2. extracting the gcc.spec from the source.
# rpmdev-extract gcc-xx.src.rpm
# 3. extracting optflags and ld_flags from gcc.spec:
# rpm --eval '%{optflags}'
# rpm --eval '%{build_ldflags}'
#
# I had to remove the following flags because they didn't compile for this version of libgomp:
# -Werror=format-security
# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
#
# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
' -fstack-protector-strong -fasynchronous-unwind-tables'\
' -fstack-clash-protection'
LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
CFLAGS="$OPT_FLAGS" \
CXXFLAGS="$OPT_FLAGS" \
LDFLAGS="$LDFLAGS" \
../configure \
--prefix=/usr \
--libdir=/usr/lib64 \
--enable-languages=c,c++ \
--disable-multilib \
--disable-bootstrap \
--enable-libgomp
# only build libgomp
make -j$(nproc) all-target-libgomp
make install-target-libgomp

View File

@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -
OPENBLAS_CHECKOUT_DIR="OpenBLAS"
OPENBLAS_BUILD_FLAGS="
CC=gcc
NUM_THREADS=128
USE_OPENMP=1
NO_SHARED=0

View File

@ -9,7 +9,7 @@ set -xe
function install_ubuntu() {
. /etc/os-release
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
echo "Ubuntu version ${VERSION_CODENAME} not supported"
exit
fi
@ -35,24 +35,25 @@ function install_ubuntu() {
# The xpu-smi packages
apt-get install -y flex bison xpu-smi
# Compute and Media Runtimes
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Compute and Media Runtimes
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
intel-opencl-icd intel-level-zero-gpu level-zero \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
else # jammy
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
else # rolling driver
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
fi
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
# Install Intel Support Packages
apt-get install -y ${XPU_PACKAGES}
@ -65,7 +66,7 @@ function install_ubuntu() {
function install_rhel() {
. /etc/os-release
if [[ "${ID}" == "rhel" ]]; then
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
echo "RHEL version ${VERSION_ID} not supported"
exit
fi
@ -146,7 +147,7 @@ function install_sles() {
XPU_DRIVER_VERSION=""
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2523"
XPU_DRIVER_VERSION="/lts/2350"
fi
# Default use Intel® oneAPI Deep Learning Essentials 2025.1

View File

@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
;;
*)

View File

@ -50,10 +50,6 @@ RUN rm install_ninja.sh
ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
# Build a newer version of libgomp than that supported in in Almalinux 8.
COPY ./common/install_libgomp.sh install_libgomp.sh
RUN bash ./install_libgomp.sh && rm install_libgomp.sh
# git236+ would refuse to run git commands in repos owned by other users
# Which causes version check to fail, as pytorch repo is bind-mounted into the image
# Override this behaviour by treating every folder as safe

View File

@ -87,7 +87,11 @@ case ${image} in
MANY_LINUX_VERSION="2_28"
DEVTOOLSET_VERSION="11"
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
;;
manylinux2_28-builder:xpu)

View File

@ -1,11 +1,15 @@
sphinx==7.2.6
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 7.2.6
#Pinned versions: 5.3.0
pytorch_sphinx_theme2==0.2.0
#Description: This is needed to generate PyTorch docs
#Pinned versions: 0.2.0
standard-imghdr==3.13.0; python_version >= "3.13"
#Description: This is needed by Sphinx, so it needs to be added here.
# The reasons are as follows:
# 1) This module has been removed from the Python standard library since Python 3.13(https://peps.python.org/pep-0594/#imghdr);
# 2) The current version of Sphinx (5.3.0) is not compatible with Python 3.13.
# Once Sphinx is upgraded to a version compatible with Python 3.13 or later, we can remove this dependency.
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@71e55749be14ceb56e7f8211a9fb649866b87ad4#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
# something related to Docker setup. We can investigate this later.
@ -32,17 +36,17 @@ tensorboard==2.18.0 ; python_version >= "3.13"
#Description: This is used to generate PyTorch docs
#Pinned versions: 2.13.0
breathe==4.36.0
breathe==4.34.0
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 4.36.0
#Pinned versions: 4.34.0
exhale==0.3.7
exhale==0.2.3
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.3.7
#Pinned versions: 0.2.3
docutils==0.20
docutils==0.16
#Description: This is used to generate PyTorch C++ docs
#Pinned versions: 0.20
#Pinned versions: 0.16
bs4==0.0.1
#Description: This is used to generate PyTorch C++ docs
@ -52,13 +56,13 @@ IPython==8.12.0
#Description: This is used to generate PyTorch functorch docs
#Pinned versions: 8.12.0
myst-nb==1.3.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 1.3.0
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
python-etcd==0.4.5
sphinx-copybutton==0.5.0
sphinx-design==0.6.1
sphinx-design==0.4.0
sphinxcontrib-mermaid==1.0.0
myst-parser==4.0.1
myst-parser==0.18.1

View File

@ -1 +1 @@
3.5.1
3.5.0

View File

@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt
ARG PALLAS
ARG CUDA_VERSION
# Install JAX with CUDA support (for Pallas)
COPY ./common/install_jax.sh install_jax.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

View File

@ -8,11 +8,9 @@ from abc import ABC, abstractmethod
try:
from collections.abc import Callable # Python 3.11+
from typing import Any, Required, TypedDict
from typing import Any, Callable, Required, TypedDict # Python 3.11+
except ImportError:
from collections.abc import Callable
from typing import Any, TypedDict
from typing import Any, Callable, TypedDict
from typing_extensions import Required # Fallback for Python <3.11

View File

@ -30,6 +30,7 @@ into a tarball, with the following structure:
More specifically, `build_magma.sh` copies over the relevant files from the `package_files` directory depending on the ROCm version.
Outputted binaries should be in the `output` folder.
## Pushing
Packages can be uploaded to an S3 bucket using:

View File

@ -6,8 +6,8 @@ set -eou pipefail
# The script expects DESIRED_CUDA and PACKAGE_NAME to be set
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
# https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
# post merge of https://github.com/icl-utk-edu/magma/pull/65
MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f
# Folders for the build
PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata
@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE
# Fetch magma sources and verify checksum
pushd ${PACKAGE_DIR}
git clone https://github.com/jeffdaily/magma
git clone https://github.com/icl-utk-edu/magma
pushd magma
git checkout ${MAGMA_VERSION}
popd

View File

@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
# shellcheck disable=SC1091
source /opt/intel/oneapi/compiler/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/umf/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Enable XCCL build
export USE_XCCL=1
export USE_MPI=0
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
export USE_KINETO=0
export TORCH_XPU_ARCH_LIST=pvc
fi

View File

@ -96,6 +96,7 @@ function pip_build_and_install() {
python3 -m pip wheel \
--no-build-isolation \
--no-deps \
--no-use-pep517 \
-w "${wheel_dir}" \
"${build_target}"
fi
@ -307,28 +308,6 @@ function install_torchao() {
pip_build_and_install "git+https://github.com/pytorch/ao.git@${commit}" dist/ao
}
function install_flash_attn_cute() {
echo "Installing FlashAttention CuTe from GitHub..."
# Grab latest main til we have a pinned commit
local flash_attn_commit
flash_attn_commit=$(git ls-remote https://github.com/Dao-AILab/flash-attention.git HEAD | cut -f1)
# Clone the repo to a temporary directory
rm -rf flash-attention-build
git clone --depth 1 --recursive https://github.com/Dao-AILab/flash-attention.git flash-attention-build
pushd flash-attention-build
git checkout "${flash_attn_commit}"
# Install only the 'cute' sub-directory
pip_install -e flash_attn/cute/
popd
# remove the local repo
rm -rf flash-attention-build
echo "FlashAttention CuTe installation complete."
}
function print_sccache_stats() {
echo 'PyTorch Build Statistics'
sccache --show-stats

View File

@ -89,41 +89,23 @@ if [ "$is_main_doc" = true ]; then
make coverage
# Now we have the coverage report, we need to make sure it is empty.
# Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row
# showing the undocumented count in the third column.
# Example: | TOTAL | 99.83% | 2 |
# Count the number of lines in the file and turn that number into a variable
# $lines. The `cut -f1 ...` is to only parse the number, not the filename
# Skip the report header by subtracting 2: the header will be output even if
# there are no undocumented items.
#
# Also: see docs/source/conf.py for "coverage_ignore*" items, which should
# be documented then removed from there.
# Extract undocumented count from TOTAL row in Sphinx 7.2.6 statistics table
# The table format is: | Module | Coverage | Undocumented |
# Extract the third column (undocumented count) from the TOTAL row
undocumented=$(grep "| TOTAL" build/coverage/python.txt | awk -F'|' '{print $4}' | tr -d ' ')
if [ -z "$undocumented" ] || ! [[ "$undocumented" =~ ^[0-9]+$ ]]; then
lines=$(wc -l build/coverage/python.txt 2>/dev/null |cut -f1 -d' ')
undocumented=$((lines - 2))
if [ $undocumented -lt 0 ]; then
echo coverage output not found
exit 1
elif [ "$undocumented" -gt 0 ]; then
set +x # Disable command echoing for cleaner output
echo ""
echo "====================="
echo "UNDOCUMENTED OBJECTS:"
echo "====================="
echo ""
# Find the line number of the TOTAL row and print only what comes after it
total_line=$(grep -n "| TOTAL" build/coverage/python.txt | cut -d: -f1)
if [ -n "$total_line" ]; then
# Print only the detailed list (skip the statistics table)
tail -n +$((total_line + 2)) build/coverage/python.txt
else
# Fallback to showing entire file if TOTAL line not found
cat build/coverage/python.txt
fi
echo ""
elif [ $undocumented -gt 0 ]; then
echo undocumented objects found:
cat build/coverage/python.txt
echo "Make sure you've updated relevant .rsts in docs/source!"
echo "You can reproduce locally by running 'cd docs && make coverage && tail -n +\$((grep -n \"| TOTAL\" build/coverage/python.txt | cut -d: -f1) + 2)) build/coverage/python.txt'"
set -x # Re-enable command echoing
echo "You can reproduce locally by running 'cd docs && make coverage && cat build/coverage/python.txt'"
exit 1
fi
else

View File

@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
source /opt/intel/oneapi/ccl/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/mpi/latest/env/vars.sh
# shellcheck disable=SC1091
source /opt/intel/oneapi/pti/latest/env/vars.sh
# Check XPU status before testing
timeout 30 xpu-smi discovery || true
fi
@ -344,18 +342,8 @@ test_python_smoke() {
}
test_python_smoke_b200() {
# Targeted smoke tests for B200 including FlashAttention CuTe coverage
install_flash_attn_cute
time python test/run_test.py \
--include \
test_matmul_cuda \
test_scaled_matmul_cuda \
inductor/test_fp8 \
nn/attention/test_fa4 \
nn/attention/test_open_registry \
inductor/test_flex_flash \
$PYTHON_TEST_EXTRA_OPTION \
--upload-artifacts-while-running
# Targeted smoke tests for B200 - staged approach to avoid too many failures
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
@ -836,11 +824,6 @@ test_inductor_halide() {
assert_git_not_dirty
}
test_inductor_pallas() {
python test/run_test.py --include inductor/test_pallas.py --verbose
assert_git_not_dirty
}
test_inductor_triton_cpu() {
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
assert_git_not_dirty
@ -1741,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
test_inductor_halide
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

View File

@ -70,7 +70,7 @@ sccache --zero-stats
sccache --show-stats
# Build the wheel
python -m build --wheel --no-isolation
python -m build --wheel --no-build-isolation
if ($LASTEXITCODE -ne 0) { exit 1 }
# Install the wheel locally

View File

@ -1,11 +1,11 @@
name: 🚀 New Feature for Release
name: 🚀 Release highlight for proposed Feature
description: Submit a Release highlight for proposed Feature
labels: ["release-feature-request"]
body:
- type: textarea
attributes:
label: New Feature for Release
label: Release highlight for proposed Feature
description: >
Example: “A torch.special module, analogous to SciPy's special module.”
- type: input

View File

@ -63,7 +63,7 @@ self-hosted-runner:
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.2
- linux.rocm.gpu.gfx942.4
- linux.rocm.gfx942.docker-cache
- rocm-docker
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14

View File

@ -38,9 +38,9 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--download \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--temp_dir "$RUNNER_TEMP" \
--repo "$REPO" \
--bucket "$BUCKET" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--temp_dir $RUNNER_TEMP \
--repo $REPO \
--bucket $BUCKET \

View File

@ -47,11 +47,11 @@ runs:
run: |
python3 .github/scripts/pytest_cache.py \
--upload \
--cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
--pr_identifier "$GITHUB_REF" \
--job_identifier "$JOB_IDENTIFIER" \
--sha "$SHA" \
--test_config "$TEST_CONFIG" \
--shard "$SHARD" \
--repo "$REPO" \
--temp_dir "$RUNNER_TEMP" \
--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
--pr_identifier $GITHUB_REF \
--job_identifier $JOB_IDENTIFIER \
--sha $SHA \
--test_config $TEST_CONFIG \
--shard $SHARD \
--repo $REPO \
--temp_dir $RUNNER_TEMP \

View File

@ -1 +1 @@
07b6cbde121417a70e4dc871adb6d27030e0ce3f
3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2

View File

@ -1 +1 @@
acccf86477759b2d3500f1ae1be065f7b1e409ec
cfbc5c2f1c798991715a6b06bb3ce46478c4487c

View File

@ -1 +1 @@
e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9

View File

@ -1,125 +0,0 @@
# PyTorch Copilot Instructions
This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
## Architecture Overview
### Core Components
- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
- `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
- `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
- **torch/** - Python bindings and public API
- `torch/csrc/` - C++ Python bindings (hand-written and generated)
- `torch/csrc/autograd/` - Reverse-mode automatic differentiation
- `torch/csrc/jit/` - TorchScript JIT compiler
- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
- **tools/** - Build scripts, autograd derivatives, code generation
### The Code Generation Workflow
**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
1. Declares operator signatures, variants (function/method), and dispatch behavior
2. Gets processed by `torchgen/` to generate C++/Python bindings
3. Produces headers in `build/aten/src/ATen/` during compilation
Example entry structure:
```yaml
- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
variants: function, method
dispatch:
CPU: my_op_cpu
CUDA: my_op_cuda
```
After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
## Development Workflows
### Building from Source
**Never run `setup.py` directly** - use pip with editable install:
```bash
python -m pip install --no-build-isolation -v -e .
```
Speed up builds:
- `DEBUG=1` - Debug symbols with `-g -O0`
- `USE_CUDA=0` - Skip CUDA compilation
- `BUILD_TEST=0` - Skip C++ test binaries
- Install `ninja` (`pip install ninja`) for faster builds
- Use `ccache` for incremental compilation caching
Rebuild specific targets: `(cd build && ninja <target>)`
### Testing
**Critical**: DO NOT run entire test suites. Run specific tests only:
```bash
python test/test_torch.py TestTorch.test_specific_case
```
**Test structure**: All tests use `torch.testing._internal.common_utils`:
```python
from torch.testing._internal.common_utils import run_tests, TestCase
class TestFeature(TestCase):
def test_something(self):
# Use self.assertEqual for tensor comparisons
pass
if __name__ == "__main__":
run_tests()
```
**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
### Linting
Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
## Project-Specific Conventions
### Memory and Storage
- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
- CUDA device info lives in storage objects
### Python-C++ Integration (`torch/csrc/`)
- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
### Dispatch System
- PyTorch uses operator dispatch to route calls to backend-specific kernels
- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
## Git Workflow (AI Agent Specific)
When preparing PRs from this environment:
```bash
git stash -u
git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch
git stash pop
# Resolve conflicts if necessary
```
## Common Gotchas
1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
4. **No internet access** - DO NOT attempt to install dependencies during development
## Key Files Reference
- `AGENTS.md` - Instructions specific to AI coding agents
- `CONTRIBUTING.md` - Comprehensive human contributor guide
- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
- `aten/src/ATen/native/README.md` - Operator implementation guide
- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
## Performance Debugging
Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.

22
.github/labeler.yml vendored
View File

@ -138,8 +138,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -149,8 +148,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -160,21 +158,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm
"ciflow/mps":
- aten/src/ATen/mps/**
- aten/src/ATen/native/mps/**
- torch/_inductor/codegen/mps.py
- test/test_mps.py
- test/inductor/test_mps_basic.py
"ciflow/h100-symm-mem":
- torch/csrc/distributed/c10d/symm_mem/**
- torch/distributed/_symmetric_memory/**
- test/distributed/**/*mem*
- test/distributed/**/*mem*/**

View File

@ -10,4 +10,3 @@
pathFilter:
- 'torch/csrc/inductor/aoti_torch/c/*'
- 'torch/csrc/inductor/aoti_torch/generated/*'
- 'torch/csrc/stable/c/*'

View File

@ -2,8 +2,8 @@ tracking_issue: 24422
ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-distributed
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@ -22,8 +22,6 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic
- ciflow/inductor-rocm
- ciflow/inductor-rocm-mi200
- ciflow/inductor-rocm-mi300
- ciflow/linux-aarch64
- ciflow/mps
- ciflow/nightly
@ -35,13 +33,11 @@ ciflow_push_tags:
- ciflow/quantization-periodic
- ciflow/riscv64
- ciflow/rocm
- ciflow/rocm-mi200
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/rocm-navi31
- ciflow/s390
- ciflow/slow
- ciflow/slow-rocm-mi200
- ciflow/torchbench
- ciflow/triton_binaries
- ciflow/trunk

View File

@ -1,11 +1,10 @@
# Delete old branches
import os
import re
from collections.abc import Callable
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Any
from typing import Any, Callable
from github_utils import gh_fetch_json_dict, gh_graphql
from gitutils import GitRepo

View File

@ -8,11 +8,10 @@ import re
import subprocess
import sys
import warnings
from collections.abc import Callable
from enum import Enum
from functools import cache
from logging import info
from typing import Any, Optional
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen
import yaml

View File

@ -11,8 +11,7 @@ import sys
import time
import urllib
import urllib.parse
from collections.abc import Callable
from typing import Any, Optional
from typing import Any, Callable, Optional
from urllib.request import Request, urlopen

View File

@ -3,9 +3,8 @@
import json
import os
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, cast, Optional, Union
from typing import Any, Callable, cast, Optional, Union
from urllib.error import HTTPError
from urllib.parse import quote
from urllib.request import Request, urlopen

View File

@ -4,10 +4,10 @@ import os
import re
import tempfile
from collections import defaultdict
from collections.abc import Callable, Iterator
from collections.abc import Iterator
from datetime import datetime
from functools import wraps
from typing import Any, cast, Optional, TypeVar, Union
from typing import Any, Callable, cast, Optional, TypeVar, Union
T = TypeVar("T")

View File

@ -34,9 +34,6 @@ python3 torch/utils/data/datapipes/gen_pyi.py
# Also check generated pyi files
find torch -name '*.pyi' -exec git add --force -- "{}" +
# Print current environment
python3 -m pip freeze
RC=0
# Run lintrunner on all files
if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then

View File

@ -17,12 +17,12 @@ import re
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Callable, Iterable
from collections.abc import Iterable
from dataclasses import dataclass
from functools import cache
from pathlib import Path
from re import Pattern
from typing import Any, cast, NamedTuple, Optional
from typing import Any, Callable, cast, NamedTuple, Optional
from warnings import warn
import yaml

View File

@ -97,8 +97,8 @@ jobs:
shell: bash
run: |
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
if [[ $ngpu -lt 4 ]]; then
echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
exit 1
fi

View File

@ -344,21 +344,5 @@ jobs:
if-no-files-found: ignore
path: ./**/core.[1-9]*
- name: Authenticate with AWS
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
# The max duration enforced by the server side
role-duration-seconds: 18000
aws-region: us-east-1
- name: Upload the benchmark results
uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
with:
benchmark-results-dir: test/test-reports
dry-run: false
schema-version: v3
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Teardown XPU
uses: ./.github/actions/teardown-xpu

View File

@ -56,8 +56,6 @@ jobs:
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.10-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-py3.14-clang12,
pytorch-linux-jammy-rocm-n-py3,
@ -67,10 +65,9 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
@ -80,11 +77,9 @@ jobs:
pytorch-linux-noble-riscv64-py3.12-gcc14
]
include:
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
runner: linux.arm64.m7g.4xlarge
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
runner: linux.arm64.m7g.4xlarge
timeout-minutes: 600
# Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358
@ -119,22 +114,6 @@ jobs:
with:
docker-image: ${{ steps.build-docker-image.outputs.docker-image }}
- name: Generate output
if: contains(matrix.docker-image-name, 'rocm')
id: generate_output
run: |
docker_image_name="${{ matrix.docker-image-name }}"
docker_image_tag="${{ steps.build-docker-image.outputs.docker-image }}"
echo "${docker_image_name}=${docker_image_tag}" >> docker-builds-output-${docker_image_name}.txt
- name: Upload artifacts
uses: actions/upload-artifact@v4.4.0
if: contains(matrix.docker-image-name, 'rocm')
with:
name: docker-builds-artifacts-${{ matrix.docker-image-name }}
retention-days: 14
path: ./docker-builds-output-${{ matrix.docker-image-name }}.txt
- uses: nick-fields/retry@7152eba30c6575329ac0576536151aca5a72780e # v3.0.0
name: Push to https://ghcr.io/
id: push-to-ghcr-io

View File

@ -0,0 +1,55 @@
name: docker-cache-mi300
on:
# run every 6 hours
schedule:
- cron: 0 0,6,12,18 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
docker-cache:
if: github.repository_owner == 'pytorch'
runs-on: rocm-docker
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
with:
no-sudo: true
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
with:
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
push: false
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
- name: Tar and upload to S3 bucket
run: |
sudo docker save -o ~/docker-data/pytorch/pytorch_docker_image.tar ${{ steps.calculate-docker-image.outputs.docker-image }}
sudo rclone copy -P --s3-upload-concurrency 64 --s3-chunk-size 200M --s3-upload-cutoff 300M ~/docker-data/pytorch/pytorch_docker_image.tar oci:pytorchbucket0002/pytorch_docker_image --progress

View File

@ -1,105 +0,0 @@
name: docker-cache-rocm
on:
workflow_run:
workflows: [docker-builds]
branches: [main, release]
types:
- completed
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
actions: read
jobs:
download-docker-builds-artifacts:
if: github.repository_owner == 'pytorch'
name: download-docker-builds-artifacts
runs-on: ubuntu-latest
outputs:
pytorch-linux-jammy-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}
pytorch-linux-noble-rocm-n-py3: ${{ steps.process-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}
pytorch-linux-jammy-rocm-n-py3-benchmarks: ${{ steps.process-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}
steps:
- name: Download artifacts
uses: actions/download-artifact@v4.1.7
with:
run-id: ${{ github.event.workflow_run.id }}
path: ./docker-builds-artifacts
merge-multiple: true
github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Process artifacts
id: process-artifacts
run: |
ls -R ./docker-builds-artifacts
cat ./docker-builds-artifacts/*txt >> "${GITHUB_OUTPUT}"
cat "${GITHUB_OUTPUT}"
docker-cache:
if: github.repository_owner == 'pytorch'
needs: download-docker-builds-artifacts
strategy:
fail-fast: false
matrix:
runner: [linux.rocm.gfx942.docker-cache]
docker-image: [
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}",
"${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3-benchmarks }}"
]
runs-on: "${{ matrix.runner }}"
steps:
- name: debug
run: |
JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}"
echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}"
- name: configure aws credentials
id: aws_creds
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
aws-region: us-east-1
role-duration-seconds: 18000
- name: Login to Amazon ECR
id: login-ecr
continue-on-error: false
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
- name: Generate ghrc.io tag
id: ghcr-io-tag
run: |
ecr_image="${{ matrix.docker-image }}"
ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}"
echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT"
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }}
- name: Save as tarball
run: |
docker_image_tag=${{ matrix.docker-image }}
docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":"
docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-"
ref_name=${{ github.event.workflow_run.head_branch }}
if [[ $ref_name =~ "release/" ]]; then
ref_suffix="release"
elif [[ $ref_name == "main" ]]; then
ref_suffix="main"
else
echo "Unexpected branch in ref_name: ${ref_name}" && exit 1
fi
docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }}
# mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention
docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }}
mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar

View File

@ -37,6 +37,7 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: "linux.c7i.12xlarge"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -72,7 +72,7 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.arm64.m7g.4xlarge
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },

View File

@ -83,8 +83,8 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -117,7 +117,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-noble-xpu-n-py3.10
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
@ -137,7 +137,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-noble-xpu-n-py3.10
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}

View File

@ -7,7 +7,6 @@ on:
- release/*
tags:
- ciflow/inductor-rocm/*
- ciflow/inductor-rocm-mi300/*
workflow_dispatch:
concurrency:

View File

@ -1,13 +1,13 @@
name: inductor-rocm-mi200
name: inductor-rocm
on:
schedule:
- cron: 0 */3 * * *
- cron: 0 * * * *
push:
branches:
- release/*
tags:
- ciflow/inductor-rocm-mi200/*
- ciflow/inductor-rocm/*
workflow_dispatch:
concurrency:

View File

@ -81,32 +81,6 @@ jobs:
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
secrets: inherit
inductor-pallas-build:
name: inductor-pallas-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
]}
secrets: inherit
inductor-pallas-test:
name: inductor-pallas-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-pallas-build
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
secrets: inherit
inductor-triton-cpu-build:
name: inductor-triton-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -33,7 +33,7 @@ jobs:
with:
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.m7g.4xlarge
test-matrix: |
{ include: [

View File

@ -5,11 +5,9 @@ on:
- cron: 0 0 * * *
push:
tags:
# NOTE: Doc build pipelines should only get triggered on:
# Major or minor release candidates builds
- v[0-9]+.[0-9]+.0+-rc[0-9]+
# Final RC for major, minor and patch releases
- v[0-9]+.[0-9]+.[0-9]+
# NOTE: Doc build pipelines should only get triggered on release candidate builds
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
- ciflow/nightly/*
workflow_dispatch:

View File

@ -60,7 +60,7 @@ jobs:
with:
build-environment: linux-jammy-aarch64-py3.10
runner: linux.arm64.m7g.4xlarge
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
test-matrix: |
{ include: [
{ config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },

View File

@ -11,6 +11,7 @@ on:
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi200/*
branches:
- release/*

View File

@ -11,7 +11,6 @@ on:
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi300/*
branches:
- release/*

View File

@ -342,16 +342,16 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# This should sync with the build in xpu.yml but xpu uses a larger runner
# sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },

View File

@ -6,7 +6,6 @@ on:
- main
- release/*
tags:
- ciflow/rocm/*
- ciflow/rocm-mi300/*
workflow_dispatch:
schedule:

View File

@ -1,16 +1,15 @@
name: rocm-mi200
name: rocm
on:
push:
branches:
- release/*
tags:
- ciflow/rocm-mi200/*
- ciflow/rocm/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
- cron: 0 */3 * * *
- cron: 0 * * * *
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}

View File

@ -1,81 +0,0 @@
# This workflow is dedicated to host slow jobs that are run only periodically because
# they are too slow to run in every commit. The list of slow tests can be found in
# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
name: slow-rocm-mi200
on:
push:
branches:
- release/*
tags:
- ciflow/slow/*
- ciflow/slow-rocm-mi200/*
schedule:
- cron: 0 */3 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -105,6 +105,36 @@ jobs:
test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3_10-clang18-asan-build:
name: linux-jammy-py3.10-clang18-asan
uses: ./.github/workflows/_linux-build.yml

View File

@ -5,9 +5,7 @@
# Flow:
# 1. Builds PyTorch with CUDA 12.8+ and sm100 architecture for B200
# 2. Runs smoke tests on linux.dgx.b200 runner
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke_b200() function
# - Includes matmul, scaled_matmul, FP8, and FlashAttention CuTe tests
# - FlashAttention CuTe DSL is installed as part of test execution
# 3. Tests executed are defined in .ci/pytorch/test.sh -> test_python_smoke() function
#
# Triggered by:
# - Pull requests modifying this workflow file

View File

@ -41,6 +41,7 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '9.0'

View File

@ -1,83 +0,0 @@
name: trunk-rocm-mi300
on:
push:
branches:
- main
- release/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1.b" },
{ config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
{ config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx942.4.b" },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -5,23 +5,21 @@ on:
workflows:
- pull
- trunk
- trunk-rocm-mi300
- periodic
- periodic-rocm-mi200
- periodic-rocm-mi300
- inductor
- unstable
- slow
- slow-rocm-mi200
- unstable-periodic
- inductor-periodic
- rocm-mi200
- rocm
- rocm-mi300
- rocm-mi355
- inductor-micro-benchmark
- inductor-micro-benchmark-x86
- inductor-cu124
- inductor-rocm-mi200
- inductor-rocm
- inductor-rocm-mi300
- mac-mps
- linux-aarch64

View File

@ -47,15 +47,15 @@ jobs:
]}
secrets: inherit
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -74,17 +74,17 @@ jobs:
]}
secrets: inherit
linux-noble-xpu-n-py3_10-test:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-test:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_xpu-test.yml
needs: linux-noble-xpu-n-py3_10-build
needs: linux-jammy-xpu-n-py3_10-build
permissions:
id-token: write
contents: read
with:
build-environment: linux-noble-xpu-n-py3.10
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
secrets: inherit
windows-xpu-n-1-build:

View File

@ -143,8 +143,7 @@ init_command = [
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
'numpy==2.3.4 ; python_version >= "3.14"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',
@ -186,8 +185,6 @@ include_patterns = [
'aten/src/ATen/native/nested/cuda/*.h',
'aten/src/ATen/native/nested/*.cpp',
'aten/src/ATen/native/nested/*.h',
'aten/src/ATen/xpu/**/*.h',
'aten/src/ATen/xpu/**/*.cpp',
'c10/**/*.cpp',
'c10/**/*.h',
'torch/*.h',
@ -1404,7 +1401,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'usort==1.0.8.post1',
'isort==6.0.1',
'ruff==0.14.4', # sync with RUFF
'ruff==0.13.1', # sync with RUFF
]
is_formatter = true
@ -1539,7 +1536,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.14.4', # sync with PYFMT
'ruff==0.13.1', # sync with PYFMT
]
is_formatter = true

View File

@ -234,17 +234,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
option(USE_LSAN "Use Leak Sanitizer" OFF)
option(USE_TSAN "Use Thread Sanitizer" OFF)
# Track whether USE_CUDA was explicitly set by the user (before option() is called)
# If USE_CUDA is already defined in cache, it means user explicitly set it
if(DEFINED CACHE{USE_CUDA})
set(_USE_CUDA_EXPLICITLY_SET TRUE)
else()
set(_USE_CUDA_EXPLICITLY_SET FALSE)
endif()
option(USE_CUDA "Use CUDA" ON)
option(USE_XPU "Use XPU" ON)
cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
@ -736,44 +726,6 @@ if(NOT DEFINED USE_BLAS)
set(USE_BLAS ON)
endif()
# Prioritized Text Linker Optimization
if(USE_PRIORITIZED_TEXT_FOR_LD)
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
execute_process(
COMMAND ${Python_EXECUTABLE}
${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
--filein "${LINKER_SCRIPT_FILE_IN}"
--fout "${LINKER_SCRIPT_FILE_OUT}"
RESULT_VARIABLE _gen_result
OUTPUT_VARIABLE _gen_output
ERROR_VARIABLE _gen_error
)
if(NOT _gen_result EQUAL 0)
message(FATAL_ERROR
"Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
endif()
append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
else()
if(LINUX AND CPU_AARCH64)
message(WARNING [[
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
]])
endif()
endif()
# Build libtorch mobile library, which contains ATen/TH ops and native support
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
if(INTERN_BUILD_MOBILE)
@ -1440,6 +1392,9 @@ if(BUILD_JNI)
add_subdirectory(android/pytorch_android)
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
# Parse custom debug info
if(DEFINED USE_CUSTOM_DEBINFO)
string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
@ -1479,5 +1434,56 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
DESTINATION "${CMAKE_INSTALL_BINDIR}")
endif()
include(cmake/Summary.cmake)
caffe2_print_configuration_summary()
if(USE_PRIORITIZED_TEXT_FOR_LD)
add_compile_options(
$<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
$<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
)
set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
add_custom_command(
OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
COMMENT "Generating prioritized text linker files"
VERBATIM
)
add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
if(BUILD_PYTHON)
set(LINKER_OPT_TARGETS torch_python)
endif()
if(NOT BUILD_LIBTORCHLESS)
list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
if(USE_CUDA)
list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
endif()
if(USE_XPU)
list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
endif()
if(USE_ROCM)
list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
endif()
endif()
foreach(tgt IN LISTS LINKER_OPT_TARGETS)
if(TARGET ${tgt})
add_dependencies("${tgt}" generate_linker_script)
target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
else()
message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
endif()
endforeach()
else()
if(LINUX AND CPU_AARCH64)
message(WARNING [[
It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
]])
endif()
endif()

View File

@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision & Grouped GEMMs
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
- [Running `pyrefly`](#running-pyrefly)
- [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
- `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
@ -350,32 +350,15 @@ make lint
Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)
#### Running `pyrefly`
#### Running `mypy`
[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
**Getting Started with Pyrefly:**
To run type checking on the PyTorch codebase:
```bash
pyrefly check
```
For more detailed error information with summaries:
```bash
pyrefly check --summarize-errors
```
**Learn More:**
- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
`mypy` is an optional static type checker for Python. We have multiple `mypy`
configs for the PyTorch codebase that are automatically validated against whenever the linter is run.
See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
for more information on how to set up `mypy` and tackle type annotation
tasks.
### C++ Unit Testing

View File

@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors

View File

@ -1,7 +1,7 @@
# Security Policy
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
- [**Using PyTorch Securely**](#using-pytorch-securely)
- [**Using Pytorch Securely**](#using-pytorch-securely)
- [Untrusted models](#untrusted-models)
- [TorchScript models](#torchscript-models)
- [Untrusted inputs](#untrusted-inputs)
@ -10,30 +10,28 @@
- [**CI/CD security principles**](#cicd-security-principles)
## Reporting Security Issues
Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch.
Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch.
However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem.
Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new
All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
**Note on crashes and out of bounds access**: PyTorch is a computational framework that performs operations on behalf of the caller. Like many low-level libraries, PyTorch generally does not validate all inputs to every function—the responsibility for providing valid arguments lies with the calling code. While crashes and out of bounds memory access should be reported as bugs, they are generally not considered security vulnerabilities in PyTorch's threat model.
All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework.
Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported:
https://www.facebook.com/whitehat
## Using PyTorch Securely
**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
## Using Pytorch Securely
**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package).
### Untrusted models
Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources].
**Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing).
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details.
Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs.
@ -45,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de
### TorchScript models
TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load.
### Untrusted inputs during training and prediction
@ -61,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some
### Data privacy
**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits).
**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and:
- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment)
- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits).
### Using distributed features

View File

@ -174,12 +174,6 @@ class TORCH_API Context {
static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN();
}
static long versionRuntimeCuDNN() {
return detail::getCUDAHooks().versionRuntimeCuDNN();
}
static long versionCuDNNFrontend() {
return detail::getCUDAHooks().versionCuDNNFrontend();
}
static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER();
}

View File

@ -6,7 +6,6 @@
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <torch/headeronly/core/Dispatch.h>
#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
} \
} while (0)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
#define AT_DISPATCH_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
return __VA_ARGS__(); \
}
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
} // namespace detail
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// but we're just being safe (and it doesn't hurt.) Note we must
// use it to shut up warnings about unused store.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH_TMPL( \
RECORD_KERNEL_FUNCTION_DTYPE, \
TORCH_CHECK_NOT_IMPLEMENTED, \
TYPE, \
NAME, \
__VA_ARGS__)
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
constexpr const char* at_dispatch_name = NAME; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK_NOT_IMPLEMENTED( \
false, \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
toString(_st), \
"'"); \
} \
C10_DIAGNOSTIC_POP() \
}()
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \

View File

@ -1,8 +1,3 @@
#pragma once
#include <torch/headeronly/core/Dispatch_v2.h>
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
#include <ATen/Dispatch.h>
// This is a new implementation of the AT_DISPATCH macro family from
@ -79,19 +74,41 @@
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right.
// Public API macros
// See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
AT_DISPATCH_SWITCH, \
AT_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
#define AT_WRAP(...) __VA_ARGS__
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
// Helper macros
// Unused helper macros, kept for BC:
#define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
num_args = 60
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1):
args = ', '.join(f'_{i}' for i in range(1, i+1))
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@ -112,6 +135,8 @@ for i in range(1, num_args+1):
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)

View File

@ -226,8 +226,8 @@ template <
typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
virtual ~CachingHostAllocatorImpl() {
if (active_) {
active_ = false;
active_ = false;
if (pinned_use_background_threads()) {
getBackgroundThreadPool()->waitWorkComplete();
}
}
@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
if (pinned_use_background_threads()) {
// Launch the background thread and process events in a loop.
static bool background_thread_flag [[maybe_unused]] = [this] {
active_ = true;
getBackgroundThreadPool()->run([&]() {
while (active_) {
process_events();
@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
std::deque<std::pair<E, B*>> events_; // event queue paired with block
// Indicates whether the event-processing thread pool is active.
// Indicates whether the object is active.
// Set to false in the destructor to signal background threads to stop.
std::atomic<bool> active_{false};
std::atomic<bool> active_{true};
protected:
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};

View File

@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
auto vals = svreinterpret_u16_bf16(values);
vals = sveor_u16_x(ptrue, vals, mask);
return svreinterpret_bf16_u16(vals);
}
};
Vectorized<BFloat16> round() const;
Vectorized<BFloat16> tan() const;
Vectorized<BFloat16> tanh() const;
@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
return convert_float_bfloat16(v1, v2); \
}
DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
DEFINE_BF16_FUNC_VIA_FLOAT(angle)
DEFINE_BF16_FUNC_VIA_FLOAT(acos)
DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
DEFINE_BF16_FUNC_VIA_FLOAT(asin)
DEFINE_BF16_FUNC_VIA_FLOAT(atan)
DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
DEFINE_BF16_FUNC_VIA_FLOAT(erf)
DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
DEFINE_BF16_FUNC_VIA_FLOAT(exp)
DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
DEFINE_BF16_FUNC_VIA_FLOAT(i0)
DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
DEFINE_BF16_FUNC_VIA_FLOAT(log)
DEFINE_BF16_FUNC_VIA_FLOAT(log2)
DEFINE_BF16_FUNC_VIA_FLOAT(log10)
DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
DEFINE_BF16_FUNC_VIA_FLOAT(sin)
DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
DEFINE_BF16_FUNC_VIA_FLOAT(cos)
DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
DEFINE_BF16_FUNC_VIA_FLOAT(floor)
DEFINE_BF16_FUNC_VIA_FLOAT(round)
DEFINE_BF16_FUNC_VIA_FLOAT(tan)
DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
DEFINE_BF16_FUNC_VIA_FLOAT(angle);
DEFINE_BF16_FUNC_VIA_FLOAT(acos);
DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
DEFINE_BF16_FUNC_VIA_FLOAT(asin);
DEFINE_BF16_FUNC_VIA_FLOAT(atan);
DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
DEFINE_BF16_FUNC_VIA_FLOAT(erf);
DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
DEFINE_BF16_FUNC_VIA_FLOAT(exp);
DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
DEFINE_BF16_FUNC_VIA_FLOAT(i0);
DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
DEFINE_BF16_FUNC_VIA_FLOAT(log);
DEFINE_BF16_FUNC_VIA_FLOAT(log2);
DEFINE_BF16_FUNC_VIA_FLOAT(log10);
DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
DEFINE_BF16_FUNC_VIA_FLOAT(sin);
DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
DEFINE_BF16_FUNC_VIA_FLOAT(cos);
DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
DEFINE_BF16_FUNC_VIA_FLOAT(floor);
DEFINE_BF16_FUNC_VIA_FLOAT(round);
DEFINE_BF16_FUNC_VIA_FLOAT(tan);
DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
const Vectorized<BFloat16>& other) const {

View File

@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
uint32_t mask = -1;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
if (fp16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
mask =
uint32_t mask =
fp16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
if (bf16_reduction !=
at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
mask =
uint32_t mask =
bf16_reduction ==
at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
// on Blackwell+, we fake a n > 1 matmul when querying heuristics
// to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
#ifndef USE_ROCM
const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
#else
const bool lie_to_cublaslt = false;
#endif
if (lie_to_cublaslt) {
CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
FakeBdesc.descriptor(),
FakeCdesc.descriptor(),
FakeCdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
} else {
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
}
TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
ltHandle,
computeDesc.descriptor(),
Adesc.descriptor(),
Bdesc.descriptor(),
Cdesc.descriptor(),
Cdesc.descriptor(),
preference.descriptor(),
1,
&heuristicResult,
&returnedResult));
if (returnedResult == 0) {
cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
}
@ -1597,7 +1572,7 @@ bool gemm_and_bias(
}
using opmath_t = at::opmath_type<Dtype>;
opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
opmath_t beta_val = 0; // bias is added in epilogue
cudaDataType_t abType = CUDA_R_32F;
cudaDataType_t cType = CUDA_R_32F;
@ -1686,22 +1661,15 @@ bool gemm_and_bias(
_syncCurrentWithCarveoutStream(stream, true);
}
#endif
const auto epilogue = [&]() -> cublasLtEpilogue_t {
// The cuBLAS documentation indicates that
// *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
// but we keep it verbose here for clarity.
switch (activation) {
case GEMMAndBiasActivationEpilogue::RELU:
return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
case GEMMAndBiasActivationEpilogue::GELU:
return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
default:
return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
}
}();
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
}
if (bias) {
if (bias != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
}

View File

@ -213,7 +213,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
scan_op,
num_items,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
@ -471,7 +471,7 @@ inline void exclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
init_value,
num_items,
at::cuda::getCurrentCUDAStream());
C10_CUDA_KERNEL_LAUNCH_CHECK();
C10_HIP_KERNEL_LAUNCH_CHECK();
#else
// non synchronizing cub call
// even though cub is supposed to support tensors with int_max elements, in reality it doesn't,

View File

@ -24,13 +24,7 @@ namespace detail {
// radix_sort_pairs doesn't interact with value_t other than to copy
// the data, so we can save template instantiations by reinterpreting
// it as an opaque type.
// We use native integer types for 1/2/4/8-byte values to reduce
// register usage in CUDA kernels. For sizes > 8 fall back to char array.
template <int N> struct alignas(N) OpaqueType { char data[N]; };
template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
template<typename key_t, int value_size>
void radix_sort_pairs_impl(

View File

@ -21,7 +21,6 @@
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#endif
#if AT_MAGMA_ENABLED()
@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
#endif
}
long CUDAHooks::versionRuntimeCuDNN() const {
#if AT_CUDNN_ENABLED()
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
#else
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionCuDNNFrontend() const {
#if AT_CUDNN_ENABLED()
return CUDNN_FRONTEND_VERSION;
#else
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED()
return MIOPEN_VERSION_MAJOR * 10000 +

View File

@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
long versionMIOpen() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;

View File

@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionRuntimeCuDNN() const {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionCuDNNFrontend() const {
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionMIOpen() const {
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::Negative,
DispatchKey::Conjugate,
DispatchKey::XLA,
DispatchKey::XPU,
DispatchKey::HPU,
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,

View File

@ -0,0 +1,239 @@
#pragma once
#include <c10/hip/HIPCachingAllocator.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10::hip {
// Takes a valid HIPAllocator (of any sort) and turns it into
// an allocator pretending to be a CUDA allocator. See
// Note [Masquerading as CUDA]
class HIPAllocatorMasqueradingAsCUDA final : public HIPCachingAllocator::HIPAllocator {
HIPCachingAllocator::HIPAllocator* allocator_;
public:
explicit HIPAllocatorMasqueradingAsCUDA(HIPCachingAllocator::HIPAllocator* allocator)
: allocator_(allocator) {}
virtual ~HIPAllocatorMasqueradingAsCUDA() = default;
// From c10::Allocator
DataPtr allocate(size_t size) override {
DataPtr r = allocator_->allocate(size);
r.unsafe_set_device(Device(c10::DeviceType::CUDA, r.device().index()));
return r;
}
bool is_simple_data_ptr(const DataPtr& data_ptr) const override {
return allocator_->is_simple_data_ptr(data_ptr);
}
DeleterFnPtr raw_deleter() const override {
return allocator_->raw_deleter();
}
void copy_data(void* dest, const void* src, std::size_t count) const final {
allocator_->copy_data(dest, src, count);
}
// From DeviceAllocator
bool initialized() override {
return allocator_->initialized();
}
void emptyCache(MempoolId_t mempool_id = {0, 0}) override {
allocator_->emptyCache(mempool_id);
}
void recordStream(const DataPtr& ptr, c10::Stream stream) override {
HIPStream hip_stream = HIPStream(stream);
recordStream(ptr, hip_stream);
}
CachingDeviceAllocator::DeviceStats getDeviceStats(c10::DeviceIndex device) override {
return allocator_->getDeviceStats(device);
}
void resetAccumulatedStats(c10::DeviceIndex device) override {
allocator_->resetAccumulatedStats(device);
}
void resetPeakStats(c10::DeviceIndex device) override {
allocator_->resetPeakStats(device);
}
// From CUDAAllocator
void* raw_alloc(size_t nbytes) override {
return allocator_->raw_alloc(nbytes);
}
void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) override {
return allocator_->raw_alloc_with_stream(nbytes, stream);
}
void raw_delete(void* ptr) override {
allocator_->raw_delete(ptr);
}
void init(int device_count) override {
allocator_->init(device_count);
}
double getMemoryFraction(c10::DeviceIndex device) override {
return allocator_->getMemoryFraction(device);
}
void setMemoryFraction(double fraction, c10::DeviceIndex device) override {
allocator_->setMemoryFraction(fraction, device);
}
std::vector<HIPCachingAllocator::StreamSegmentSize> getExpandableSegmentSizes(c10::DeviceIndex device) override {
return allocator_->getExpandableSegmentSizes(device);
}
void enable(bool value) override {
allocator_->enable(value);
}
bool isEnabled() const override {
return allocator_->isEnabled();
}
void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) override {
allocator_->cacheInfo(device, largestBlock);
}
void* getBaseAllocation(void* ptr, size_t* size) override {
return allocator_->getBaseAllocation(ptr, size);
}
void recordStream(const DataPtr& ptr, HIPStream stream) override {
allocator_->recordStream(ptr, stream);
}
HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) override {
return allocator_->snapshot(mempool_id);
}
void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) override {
allocator_->beginAllocateToPool(device, mempool_id, filter);
}
void endAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id) override {
allocator_->endAllocateToPool(device, mempool_id);
}
void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->releasePool(device, mempool_id);
}
int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) override {
return allocator_->getPoolUseCount(device, mempool_id);
}
void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPAllocator* allocator = nullptr) override {
allocator_->createOrIncrefPool(device, mempool_id, allocator);
}
void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) override {
allocator_->setUseOnOOM(device, mempool_id);
}
bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) override {
return allocator_->checkPoolLiveAllocations(device, mempool_id, expected_live_allocations);
}
HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) override {
return allocator_->shareIpcHandle(ptr);
}
std::shared_ptr<void> getIpcDevPtr(std::string handle) override {
return allocator_->getIpcDevPtr(handle);
}
bool isHistoryEnabled() override {
return allocator_->isHistoryEnabled();
}
void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) override {
allocator_->recordHistory(enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}
void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) override {
allocator_->recordAnnotation(md);
}
void pushCompileContext(std::string& md) override {
allocator_->pushCompileContext(md);
}
void popCompileContext() override {
allocator_->popCompileContext();
}
void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) override {
allocator_->attachOutOfMemoryObserver(observer);
}
void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) override {
allocator_->attachAllocatorTraceTracker(tracker);
}
void enablePeerAccess(c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) override {
allocator_->enablePeerAccess(dev, dev_to_access);
}
hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) override {
return allocator_->memcpyAsync(dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}
std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) override {
return allocator_->getCheckpointState(device, id);
}
HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) override {
auto cpd = allocator_->setCheckpointPoolState(device, pps);
for (auto& ptr : cpd.dataptrs_allocd) {
ptr.unsafe_set_device(Device(c10::DeviceType::CUDA, ptr.device().index()));
}
return cpd;
}
std::string name() override {
return allocator_->name();
}
};
} // namespace c10::hip

View File

@ -0,0 +1,18 @@
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
namespace c10 { namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {
HIPCachingAllocator::HIPAllocator* get() {
static HIPAllocatorMasqueradingAsCUDA allocator(HIPCachingAllocator::get());
return &allocator;
}
void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream) {
HIPCachingAllocator::recordStream(ptr, stream.hip_stream());
}
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
}} // namespace c10::hip

View File

@ -0,0 +1,194 @@
#pragma once
#include <c10/hip/HIPCachingAllocator.h>
#include <ATen/hip/impl/HIPAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
namespace c10 {
// forward declaration
class DataPtr;
namespace hip {
namespace HIPCachingAllocatorMasqueradingAsCUDA {
C10_HIP_API HIPCachingAllocator::HIPAllocator* get();
C10_HIP_API void recordStreamMasqueradingAsCUDA(const DataPtr& ptr, HIPStreamMasqueradingAsCUDA stream);
inline void* raw_alloc(size_t nbytes) {
return get()->raw_alloc(nbytes);
}
inline void* raw_alloc_with_stream(size_t nbytes, hipStream_t stream) {
return get()->raw_alloc_with_stream(nbytes, stream);
}
inline void raw_delete(void* ptr) {
return get()->raw_delete(ptr);
}
inline void init(int device_count) {
return get()->init(device_count);
}
inline double getMemoryFraction(c10::DeviceIndex device) {
return get()->getMemoryFraction(device);
}
inline void setMemoryFraction(double fraction, c10::DeviceIndex device) {
return get()->setMemoryFraction(fraction, device);
}
inline void emptyCache(MempoolId_t mempool_id = {0, 0}) {
return get()->emptyCache(mempool_id);
}
inline void enable(bool value) {
return get()->enable(value);
}
inline bool isEnabled() {
return get()->isEnabled();
}
inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
return get()->cacheInfo(device, largestBlock);
}
inline void* getBaseAllocation(void* ptr, size_t* size) {
return get()->getBaseAllocation(ptr, size);
}
inline c10::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) {
return get()->getDeviceStats(device);
}
inline void resetAccumulatedStats(c10::DeviceIndex device) {
return get()->resetAccumulatedStats(device);
}
inline void resetPeakStats(c10::DeviceIndex device) {
return get()->resetPeakStats(device);
}
inline HIPCachingAllocator::SnapshotInfo snapshot(MempoolId_t mempool_id = {0, 0}) {
return get()->snapshot(mempool_id);
}
inline std::shared_ptr<HIPCachingAllocator::AllocatorState> getCheckpointState(
c10::DeviceIndex device,
MempoolId_t id) {
return get()->getCheckpointState(device, id);
}
inline HIPCachingAllocator::CheckpointDelta setCheckpointPoolState(
c10::DeviceIndex device,
std::shared_ptr<HIPCachingAllocator::AllocatorState> pps) {
return get()->setCheckpointPoolState(device, std::move(pps));
}
inline void beginAllocateToPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
std::function<bool(hipStream_t)> filter) {
get()->beginAllocateToPool(device, mempool_id, std::move(filter));
}
inline void endAllocateToPool(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->endAllocateToPool(device, mempool_id);
}
inline void recordHistory(
bool enabled,
HIPCachingAllocator::CreateContextFn context_recorder,
size_t alloc_trace_max_entries,
HIPCachingAllocator::RecordContext when,
bool clearHistory) {
return get()->recordHistory(
enabled, context_recorder, alloc_trace_max_entries, when, clearHistory);
}
inline void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& md) {
return get()->recordAnnotation(md);
}
inline void pushCompileContext(std::string& md) {
return get()->pushCompileContext(md);
}
inline void popCompileContext() {
return get()->popCompileContext();
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}
inline bool checkPoolLiveAllocations(
c10::DeviceIndex device,
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
return get()->checkPoolLiveAllocations(
device, mempool_id, expected_live_allocations);
}
inline void attachOutOfMemoryObserver(HIPCachingAllocator::OutOfMemoryObserver observer) {
return get()->attachOutOfMemoryObserver(std::move(observer));
}
inline void attachAllocatorTraceTracker(HIPCachingAllocator::AllocatorTraceTracker tracker) {
return get()->attachAllocatorTraceTracker(std::move(tracker));
}
inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->releasePool(device, mempool_id);
}
inline void createOrIncrefPool(
c10::DeviceIndex device,
MempoolId_t mempool_id,
HIPCachingAllocator::HIPAllocator* allocator_ptr = nullptr) {
get()->createOrIncrefPool(device, mempool_id, allocator_ptr);
}
inline void setUseOnOOM(c10::DeviceIndex device, MempoolId_t mempool_id) {
get()->setUseOnOOM(device, mempool_id);
}
inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) {
return get()->getPoolUseCount(device, mempool_id);
}
inline std::shared_ptr<void> getIpcDevPtr(std::string handle) {
return get()->getIpcDevPtr(std::move(handle));
}
inline HIPCachingAllocator::ShareableHandle shareIpcHandle(void* ptr) {
return get()->shareIpcHandle(ptr);
}
inline std::string name() {
return get()->name();
}
inline hipError_t memcpyAsync(
void* dst,
int dstDevice,
const void* src,
int srcDevice,
size_t count,
hipStream_t stream,
bool p2p_enabled) {
return get()->memcpyAsync(
dst, dstDevice, src, srcDevice, count, stream, p2p_enabled);
}
inline void enablePeerAccess(
c10::DeviceIndex dev,
c10::DeviceIndex dev_to_access) {
return get()->enablePeerAccess(dev, dev_to_access);
}
} // namespace HIPCachingAllocatorMasqueradingAsCUDA
} // namespace hip
} // namespace c10

View File

@ -0,0 +1,14 @@
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
// THIS IS A MASSIVE HACK. This will BREAK you Caffe2 CUDA code if you
// load ATen_hip, even if you don't ever actually use ATen_hip at runtime.
//
// If you ever link ATen_hip statically into the full library along
// with ATen_cuda (libomnibus), the loading order of this versus the regular
// ATen_cuda will be nondeterministic, and you'll nondeterministically get
// one or the other. (This will be obvious because all of your code
// will fail.)
//
// This hack can be removed once PyTorch is out-of-place HIPified, and
// doesn't pretend CUDA is HIP.
C10_REGISTER_GUARD_IMPL(CUDA, at::cuda::HIPGuardImplMasqueradingAsCUDA)

View File

@ -0,0 +1,383 @@
#pragma once
#include <ATen/hip/HIPConfig.h>
// The includes of HIPGuard.h
#include <c10/hip/impl/HIPGuardImpl.h>
#include <c10/hip/HIPMacros.h>
#include <c10/core/DeviceType.h>
#include <c10/core/impl/InlineDeviceGuard.h>
#include <c10/core/impl/InlineStreamGuard.h>
#include <c10/util/Exception.h>
#include <c10/hip/impl/HIPGuardImpl.h>
#include <ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h>
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
// Note [Masquerading as CUDA]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// c10_hip is very easy to understand: it is HIPified from c10_cuda,
// and anywhere you said CUDA, the source code now says HIP. HIPified
// PyTorch is much harder to understand: it is HIPified from regular
// PyTorch, yes, but NO source-to-source translation from CUDA to
// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP".
// For example, when you use HIPified PyTorch, you say x.cuda() to
// move a tensor onto ROCm device. We call this situation "HIP
// masquerading as CUDA".
//
// This leads to a very awkward situation when we want to call c10_hip
// code from PyTorch, since c10_hip is expecting things to be called
// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To
// fix this impedance mismatch, we have MasqueradingAsCUDA variants
// for all c10_hip classes. These translate between the "HIP" and "CUDA
// masquerading as HIP" worlds. For example,
// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a
// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type()
// returns CUDA, getDevice() reports the current HIP device as a CUDA
// device.)
//
// We should be able to delete all of these classes entirely once
// we switch PyTorch to calling a HIP a HIP.
//
// When you add a new MasqueradingAsCUDA class/function, you need to
// also update the rewrite rules in torch/utils/hipify/cuda_to_hip_mappings.py
//
//
//
// By the way, note that the cpp file associated with this also
// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with
// this HIP implementation.
struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::CUDA;
HIPGuardImplMasqueradingAsCUDA() {}
HIPGuardImplMasqueradingAsCUDA(c10::DeviceType t) {
TORCH_INTERNAL_ASSERT(t == c10::DeviceType::CUDA);
}
c10::DeviceType type() const override {
return c10::DeviceType::CUDA;
}
Device exchangeDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
Device old_device = getDevice();
if (old_device.index() != d.index()) {
C10_HIP_CHECK(hipSetDevice(d.index()));
}
return old_device;
}
Device getDevice() const override {
int device;
C10_HIP_CHECK(hipGetDevice(&device));
return Device(c10::DeviceType::CUDA, device);
}
void setDevice(Device d) const override {
TORCH_INTERNAL_ASSERT(d.is_cuda());
C10_HIP_CHECK(hipSetDevice(d.index()));
}
void uncheckedSetDevice(Device d) const noexcept override {
C10_HIP_CHECK_WARN(hipSetDevice(d.index()));
}
Stream getStream(Device d) const override {
return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap();
}
Stream getDefaultStream(Device d) const override {
return getDefaultHIPStreamMasqueradingAsCUDA(d.index());
}
Stream getNewStream(Device d, int priority = 0) const override {
return getStreamFromPoolMasqueradingAsCUDA(priority, d.index());
}
Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) const override {
return getStreamFromPoolMasqueradingAsCUDA(isHighPriority, d.index());
}
Stream exchangeStream(Stream s) const override {
HIPStreamMasqueradingAsCUDA cs(s);
auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index());
setCurrentHIPStreamMasqueradingAsCUDA(cs);
return old_stream.unwrap();
}
DeviceIndex deviceCount() const noexcept override {
int deviceCnt;
hipError_t _err;
_err = hipGetDeviceCount(&deviceCnt);
if(_err != hipErrorNoDevice && _err != hipSuccess)
C10_HIP_CHECK(_err);
return deviceCnt;
}
// Event-related functions
// Note: hipEventCreateWithFlags should be called on the same device as
// the recording stream's device.
void createEvent(
hipEvent_t* hip_event,
const EventFlag flag) const {
// Maps PyTorch's Event::Flag to HIP flag
auto hip_flag = hipEventDefault;
switch (flag) {
case EventFlag::PYTORCH_DEFAULT:
hip_flag = hipEventDisableTiming;
break;
case EventFlag::BACKEND_DEFAULT:
hip_flag = hipEventDefault;
break;
default:
TORCH_CHECK(false, "HIP event received unknown flag");
}
C10_HIP_CHECK(hipEventCreateWithFlags(hip_event, hip_flag));
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override {
if (!event) return;
auto hip_event = static_cast<hipEvent_t>(event);
int orig_device;
C10_HIP_CHECK_WARN(hipGetDevice(&orig_device));
C10_HIP_CHECK_WARN(hipSetDevice(device_index));
C10_HIP_CHECK_WARN(hipEventDestroy(hip_event));
C10_HIP_CHECK_WARN(hipSetDevice(orig_device));
}
void record(void** event,
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
"Event device index ",
device_index,
" does not match recording stream's device index ",
stream.device_index(),
".");
hipEvent_t hip_event = static_cast<hipEvent_t>(*event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};
// Moves to stream's device to record
const auto orig_device = getDevice();
setDevice(stream.device());
// Creates the event (lazily)
if (!hip_event) createEvent(&hip_event, flag);
C10_HIP_CHECK(hipEventRecord(hip_event, hip_stream));
// Makes the void* point to the (possibly just allocated) HIP event
*event = hip_event;
// Resets device
setDevice(orig_device);
}
void block(
void* event,
const Stream& stream) const override {
if (!event) return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
HIPStreamMasqueradingAsCUDA hip_stream{stream};
const auto orig_device = getDevice();
setDevice(stream.device());
C10_HIP_CHECK(hipStreamWaitEvent(
hip_stream,
hip_event,
/*flags (must be zero)=*/ 0));
setDevice(orig_device);
}
bool queryEvent(void* event) const override {
if (!event) return true;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
const hipError_t err = hipEventQuery(hip_event);
if (err != hipErrorNotReady) C10_HIP_CHECK(err);
else {
// ignore and clear the error if not ready
(void)hipGetLastError();
}
return (err == hipSuccess);
}
// Stream-related functions
bool queryStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
return hip_stream.query();
}
void synchronizeStream(const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
hip_stream.synchronize();
}
void synchronizeEvent(void* event) const override {
if (!event)
return;
hipEvent_t hip_event = static_cast<hipEvent_t>(event);
C10_HIP_CHECK(hipEventSynchronize(hip_event));
}
// Note: synchronizeDevice can be safely called from any device
void synchronizeDevice(const c10::DeviceIndex device_index) const override {
int orig_device{-1};
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
C10_HIP_CHECK(hipDeviceSynchronize());
C10_HIP_CHECK(hipSetDevice(orig_device));
}
void recordDataPtrOnStream(
const c10::DataPtr& data_ptr,
const Stream& stream) const override {
HIPStreamMasqueradingAsCUDA hip_stream{stream};
HIPCachingAllocatorMasqueradingAsCUDA::recordStreamMasqueradingAsCUDA(data_ptr, hip_stream);
}
double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
const override {
TORCH_CHECK(
event1 && event2,
"Both events must be recorded before calculating elapsed time.");
int orig_device;
C10_HIP_CHECK(hipGetDevice(&orig_device));
C10_HIP_CHECK(hipSetDevice(device_index));
hipEvent_t hip_event1 = static_cast<hipEvent_t>(event1);
hipEvent_t hip_event2 = static_cast<hipEvent_t>(event2);
float time_ms = 0;
// raise hipErrorNotReady if either event is recorded but not yet completed
C10_HIP_CHECK(hipEventElapsedTime(&time_ms, hip_event1, hip_event2));
C10_HIP_CHECK(hipSetDevice(orig_device));
return static_cast<double>(time_ms);
}
};
// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.
/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
/// the correct InlineDeviceGuard burned in. Sorry about the
/// copy-pasting.
struct HIPGuardMasqueradingAsCUDA {
explicit HIPGuardMasqueradingAsCUDA() = delete;
explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
Device original_device() const { return guard_.original_device(); }
Device current_device() const { return guard_.current_device(); }
private:
c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct OptionalHIPGuardMasqueradingAsCUDA {
explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<Device> device_opt) : guard_(device_opt) {}
explicit OptionalHIPGuardMasqueradingAsCUDA(std::optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
void set_device(Device device) { guard_.set_device(device); }
void reset_device(Device device) { guard_.reset_device(device); }
void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
std::optional<Device> original_device() const { return guard_.original_device(); }
std::optional<Device> current_device() const { return guard_.current_device(); }
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPStreamGuardMasqueradingAsCUDA {
explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
HIPStreamMasqueradingAsCUDA original_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream());
}
HIPStreamMasqueradingAsCUDA current_stream() const {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream());
}
Device current_device() const { return guard_.current_device(); }
Device original_device() const { return guard_.original_device(); }
private:
c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct OptionalHIPStreamGuardMasqueradingAsCUDA {
explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
explicit OptionalHIPStreamGuardMasqueradingAsCUDA(std::optional<Stream> stream_opt) : guard_(stream_opt) {}
OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
void reset_stream(Stream stream) { guard_.reset_stream(stream); }
std::optional<HIPStreamMasqueradingAsCUDA> original_stream() const {
auto r = guard_.original_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}
std::optional<HIPStreamMasqueradingAsCUDA> current_stream() const {
auto r = guard_.current_stream();
if (r.has_value()) {
return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value());
} else {
return std::nullopt;
}
}
void reset() { guard_.reset(); }
private:
c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
};
struct HIPMultiStreamGuardMasqueradingAsCUDA {
explicit HIPMultiStreamGuardMasqueradingAsCUDA(ArrayRef<HIPStreamMasqueradingAsCUDA> streams)
: guard_(unwrapStreams(streams)) {}
HIPMultiStreamGuardMasqueradingAsCUDA(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(const HIPMultiStreamGuardMasqueradingAsCUDA&) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
HIPMultiStreamGuardMasqueradingAsCUDA& operator=(HIPMultiStreamGuardMasqueradingAsCUDA&& other) = delete;
private:
c10::impl::InlineMultiStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
static std::vector<Stream> unwrapStreams(ArrayRef<HIPStreamMasqueradingAsCUDA> hipStreams) {
std::vector<Stream> streams;
streams.reserve(hipStreams.size());
for (const HIPStreamMasqueradingAsCUDA& hipStream : hipStreams) {
streams.push_back(hipStream);
}
return streams;
}
};
}} // namespace c10::hip

View File

@ -0,0 +1,135 @@
#pragma once
#include <c10/hip/HIPStream.h>
// Use of c10::hip namespace here makes hipification easier, because
// I don't have to also fix namespaces. Sorry!
namespace c10 { namespace hip {
// See Note [Masquerading as CUDA] for motivation
class HIPStreamMasqueradingAsCUDA {
public:
enum Unchecked { UNCHECKED };
explicit HIPStreamMasqueradingAsCUDA(Stream stream)
: HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) {
// We did the coercion unchecked; check that it was right.
TORCH_CHECK(stream.device().is_cuda() /* !!! */);
}
explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream)
// Unsafely coerce the "CUDA" stream into a HIP stream
: stream_(
HIPStream(
Stream(
Stream::UNSAFE,
Device(c10::DeviceType::HIP, stream.device_index()),
stream.id())
)
) {}
// New constructor, just for this. Does NOT coerce.
explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {}
bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ == other.stream_;
}
bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept {
return stream_ != other.stream_;
}
operator hipStream_t() const { return stream_.stream(); }
operator Stream() const {
// Unsafely coerce HIP stream into a "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}
DeviceIndex device_index() const { return stream_.device_index(); }
// Unsafely coerce HIP device into CUDA device
c10::DeviceType device_type() const { return c10::DeviceType::CUDA; }
Device device() const {
// Unsafely coerce HIP device into CUDA device
return Device(c10::DeviceType::CUDA, stream_.device_index());
}
StreamId id() const { return stream_.id(); }
bool query() const { return stream_.query(); }
void synchronize() const { stream_.synchronize(); }
int priority() const { return stream_.priority(); }
hipStream_t stream() const { return stream_.stream(); }
Stream unwrap() const {
// Unsafely coerce HIP stream into "CUDA" stream
return Stream(Stream::UNSAFE, device(), id());
}
c10::StreamData3 pack3() const noexcept {
// Unsafely coerce HIP stream into "CUDA" stream before packing
return unwrap().pack3();
}
static HIPStreamMasqueradingAsCUDA unpack3(StreamId stream_id,
DeviceIndex device_index,
c10::DeviceType device_type) {
// NB: constructor manages CUDA->HIP translation for us
return HIPStreamMasqueradingAsCUDA(Stream::unpack3(
stream_id, device_index, device_type));
}
static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); }
// New method, gets the underlying HIPStream
HIPStream hip_stream() const { return stream_; }
private:
HIPStream stream_;
};
HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device));
}
HIPStreamMasqueradingAsCUDA
inline getStreamFromPoolMasqueradingAsCUDA(const int priority, DeviceIndex device = -1) {
return HIPStreamMasqueradingAsCUDA(getStreamFromPool(priority, device));
}
HIPStreamMasqueradingAsCUDA
inline getStreamFromExternalMasqueradingAsCUDA(hipStream_t ext_stream, DeviceIndex device) {
return HIPStreamMasqueradingAsCUDA(getStreamFromExternal(ext_stream, device));
}
inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index));
}
inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) {
return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index));
}
inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) {
setCurrentHIPStream(stream.hip_stream());
}
inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) {
stream << s.hip_stream() << " (masquerading as CUDA)";
return stream;
}
}} // namespace c10::hip
namespace std {
template <>
struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> {
size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept {
return std::hash<c10::Stream>{}(s.unwrap());
}
};
} // namespace std

View File

@ -39,7 +39,7 @@ using MIOpenPoolType = at::cuda::DeviceThreadHandlePool<
miopenHandle_t getMiopenHandle() {
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(at::cuda::GetDevice(&device));
AT_CUDA_CHECK(c10::hip::GetDevice(&device));
// Thread local PoolWindows are lazily-initialized
// to avoid initialization issues that caused hangs on Windows.
@ -51,7 +51,7 @@ miopenHandle_t getMiopenHandle() {
pool->newPoolWindow());
auto handle = myPoolWindow->reserve(device);
MIOPEN_CHECK(miopenSetStream(handle, at::cuda::getCurrentCUDAStream()));
MIOPEN_CHECK(miopenSetStream(handle, c10::hip::getCurrentHIPStream()));
return handle;
}

View File

@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
m_mutex.unlock();
auto stream = getDefaultMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
dispatch_sync(stream->queue(), ^() {
stream->synchronize(SyncType::COMMIT_AND_WAIT);
});
m_mutex.lock();

View File

@ -110,9 +110,6 @@ class TORCH_API MPSStream {
return _stream;
}
MTLBuffer_t getErrorBuffer();
void checkLastError();
private:
Stream _stream;
MTLCommandQueue_t _commandQueue = nil;
@ -124,8 +121,6 @@ class TORCH_API MPSStream {
dispatch_queue_t _serialQueue = nullptr;
// CommitAndContinue is enabled by default
bool _enableCommitAndContinue = true;
// Buffer that contains last raised error
MTLBuffer_t _errorBuffer = nil;
// use synchronize() to access any of these commit functions outside MPSStream
void commit();
@ -160,7 +155,4 @@ class TORCH_API MPSStreamImpl {
MPSStreamImpl();
};
#ifdef __OBJC__
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
#endif
} // namespace at::mps

View File

@ -3,13 +3,13 @@
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSProfiler.h>
#include <ATen/mps/MPSStream.h>
#include <c10/metal/error.h>
@interface MPSGraphExecutionDescriptor ()
@property(readwrite, atomic) BOOL enableCommitAndContinue;
@end
namespace at::mps {
//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------
@ -30,10 +30,6 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
// Choose level which optimizes for GPU
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
_errorBuffer = [MPSDevice::getInstance()->device() newBufferWithLength:sizeof(c10::metal::ErrorMessages)
options:MTLResourceStorageModeShared];
std::memset([_errorBuffer contents], 0, 1024);
}
MPSStream::~MPSStream() {
@ -42,8 +38,6 @@ MPSStream::~MPSStream() {
[_executionDescriptor release];
[_compilationDescriptor release];
_executionDescriptor = nil;
[_errorBuffer release];
_errorBuffer = nil;
_compilationDescriptor = nil;
assert(_commandBuffer == nil);
@ -110,7 +104,6 @@ void MPSStream::commitAndWait() {
[_prevCommandBuffer waitUntilCompleted];
[_prevCommandBuffer release];
_prevCommandBuffer = nil;
checkLastError();
}
if (_commandBuffer) {
@ -118,7 +111,6 @@ void MPSStream::commitAndWait() {
[_commandBuffer waitUntilCompleted];
[_commandBuffer release];
_commandBuffer = nil;
checkLastError();
}
}
@ -161,7 +153,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
if (length == 0) {
return;
}
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
endKernelCoalescing();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@ -191,7 +183,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
size_t dstOffset,
uint64_t profileId,
SyncType syncType) {
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
endKernelCoalescing();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@ -244,7 +236,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
auto& profiler = getMPSProfiler();
const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
endKernelCoalescing();
if (isGraphProfilingEnabled) {
// this function call is only relevant for interval-based Signposts
@ -274,24 +266,6 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
});
}
id<MTLBuffer> MPSStream::getErrorBuffer() {
return _errorBuffer;
}
void MPSStream::checkLastError() {
auto msgs = reinterpret_cast<c10::metal::ErrorMessages*>([_errorBuffer contents]);
const auto& msg = msgs->msg[0];
if (!msgs) {
return;
}
unsigned int count = 0;
std::swap(count, msgs->count);
if (!count) {
return;
}
throw c10::AcceleratorError({msg.func, msg.file, msg.line}, 1, msg.message);
}
//-----------------------------------------------------------------
// MPSStreamImpl
//-----------------------------------------------------------------
@ -315,19 +289,4 @@ MPSStream* getDefaultMPSStream() {
return MPSStreamImpl::getInstance();
}
// Helper methods
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
__block std::optional<std::exception_ptr> block_exception;
dispatch_sync(queue, ^() {
try {
block();
} catch (...) {
block_exception = std::current_exception();
}
});
if (block_exception) {
std::rethrow_exception(*block_exception);
}
}
} // namespace at::mps

View File

@ -1009,25 +1009,12 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
}
}
static Tensor send_to_meta(const Tensor& self, const Device& device) {
Tensor out_meta;
if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
} else {
out_meta = self.to(device);
}
return out_meta;
}
Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
auto out_device = correct_out_device(self, other);
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
}
@ -1036,9 +1023,7 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
if (self._is_zerotensor()) {
if (other._is_zerotensor()) {
@ -1067,9 +1052,8 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
// hack to use the TensorIterator to get the correct broadcasting and type promotion logic
auto device_ = Device(DeviceType::Meta);
constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
auto self_meta = send_to_meta(self, device_);
auto other_meta = send_to_meta(other, device_);
auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
auto meta_out = at::_ops::add_Tensor::redispatch(
meta_dks, self.to(device_), other.to(device_), alpha);
auto get_out_like = [&] (const Tensor& tensor)
{

View File

@ -409,7 +409,7 @@ struct ConvParams {
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@ -453,7 +453,7 @@ struct ConvParams {
}
// native kernel doesn't support 64-bit non-splittable case
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
if (cudnn_version < 0 || cudnn_version > 91000) {

View File

@ -5,13 +5,9 @@
#include <c10/macros/Macros.h>
#include <c10/util/MathConstants.h>
// ROCm hip compiler doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
// ROCM hcc doesn't work well with using std:: in kernel functions
#if defined(__CUDA_ARCH__)
#include <c10/cuda/CUDAMathCompat.h>
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#endif
#define compat_exp c10::cuda::compat::exp
#define compat_ceil c10::cuda::compat::ceil
#define compat_floor c10::cuda::compat::floor
@ -21,6 +17,17 @@
#define compat_tan c10::cuda::compat::tan
#define compat_abs c10::cuda::compat::abs
#define compat_log1p c10::cuda::compat::log1p
#elif defined(__HIPCC__)
#include <c10/hip/HIPMathCompat.h>
#define compat_exp c10::hip::compat::exp
#define compat_ceil c10::hip::compat::ceil
#define compat_floor c10::hip::compat::floor
#define compat_log c10::hip::compat::log
#define compat_pow c10::hip::compat::pow
#define compat_sqrt c10::hip::compat::sqrt
#define compat_tan c10::hip::compat::tan
#define compat_abs c10::hip::compat::abs
#define compat_log1p c10::hip::compat::log1p
#else
#define compat_exp std::exp
#define compat_ceil std::ceil

Some files were not shown because too many files have changed in this diff Show More