Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-11-11 22:34:53 +08:00

Compare commits: 145 commits, ciflow/ind ... ciflow/tru (first 7c5043aae8, last ea7add4837).
@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
;;
|
||||
rocm*)
|
||||
BASE_TARGET=rocm
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -168,6 +168,18 @@ case "$tag" in
|
||||
VISION=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.11-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.11
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
|
||||
if [[ $tag =~ "jammy" ]]; then
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
@ -195,9 +207,9 @@ case "$tag" in
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
|
||||
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=11
|
||||
GCC_VERSION=13
|
||||
VISION=yes
|
||||
XPU_VERSION=2025.2
|
||||
NINJA_VERSION=1.9.0
|
||||
@ -248,6 +260,12 @@ case "$tag" in
|
||||
HALIDE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
|
||||
CUDA_VERSION=12.8.1
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
GCC_VERSION=11
|
||||
PALLAS=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-triton-cpu)
|
||||
CUDA_VERSION=12.6
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
@ -369,6 +387,7 @@ docker build \
|
||||
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
||||
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
||||
--build-arg "HALIDE=${HALIDE}" \
|
||||
--build-arg "PALLAS=${PALLAS}" \
|
||||
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
||||
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
||||
--build-arg "ACL=${ACL:-}" \
|
||||
|
||||
.ci/docker/ci_commit_pins/jax.txt (new file, 1 line)
@@ -0,0 +1 @@
0.8.0
.ci/docker/common/install_jax.sh (new executable file, 40 lines)
@@ -0,0 +1,40 @@
#!/bin/bash

set -ex

source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"

# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)

function install_jax_12() {
  echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
  pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  # Verify installation
  python -c "import jax"  # check for errors
  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}

function install_jax_13() {
  echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
  pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

  # Verify installation
  python -c "import jax"  # check for errors
  echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}

# idiomatic parameter and option handling in sh
while test $# -gt 0
do
  case "$1" in
    12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
      ;;
    13.0|13.0.*) install_jax_13;
      ;;
    *) echo "bad argument $1"; exit 1
      ;;
  esac
  shift
done
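For context, a minimal sketch of how this script is expected to be invoked; the Dockerfile change later in this diff passes ${CUDA_VERSION} as the single argument, and the concrete version string below is only an illustration that matches the 12.8.* pattern above.

# Sketch only: install the pinned JAX build for a CUDA 12.x image.
bash ./install_jax.sh 12.8.1

# get_pinned_commit (from common_utils.sh) is assumed to resolve the pin file,
# so JAX_VERSION comes from .ci/docker/ci_commit_pins/jax.txt (0.8.0 in this diff).
python -c "import jax; print(jax.__version__)"   # expected to print 0.8.0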
.ci/docker/common/install_libgomp.sh (new file, 56 lines)
@@ -0,0 +1,56 @@
#!/bin/bash
# Script used only in CD pipeline

set -ex

# install dependencies
dnf -y install gmp-devel libmpc-devel texinfo flex bison

cd /usr/local/src
# fetch source for gcc 13
git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0

mkdir -p gcc-13.3.0/build-gomp
cd gcc-13.3.0/build-gomp

# configure gcc build
# I got these flags by:
# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
#      dnf install -y dnf-plugins-core rpmdevtools
#      dnf download --source libgomp
# 2. extracting the gcc.spec from the source.
#      rpmdev-extract gcc-xx.src.rpm
# 3. extracting optflags and ld_flags from gcc.spec:
#      rpm --eval '%{optflags}'
#      rpm --eval '%{build_ldflags}'
#
# I had to remove the following flags because they didn't compile for this version of libgomp:
#   -Werror=format-security
#   -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
#   -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
#
# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.

OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
' -fstack-protector-strong -fasynchronous-unwind-tables'\
' -fstack-clash-protection'

LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'

CFLAGS="$OPT_FLAGS" \
CXXFLAGS="$OPT_FLAGS" \
LDFLAGS="$LDFLAGS" \
../configure \
  --prefix=/usr \
  --libdir=/usr/lib64 \
  --enable-languages=c,c++ \
  --disable-multilib \
  --disable-bootstrap \
  --enable-libgomp

# only build libgomp
make -j$(nproc) all-target-libgomp

make install-target-libgomp
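A quick sanity check one might run after this script finishes (a sketch, not part of the diff); the exact GOMP symbol versions exported depend on the GCC release, so the grep below is illustrative.

# Sketch: confirm the rebuilt libgomp is the one on the default linker path
ldconfig -p | grep libgomp
# List the GOMP_* version tags it exports; a GCC 13 build should carry newer
# tags (e.g. GOMP_5.x) than the stock AlmaLinux 8 (GCC 8) libgomp.
objdump -T /usr/lib64/libgomp.so.1 | grep -o 'GOMP_[0-9.]*' | sort -u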
@ -9,7 +9,7 @@ set -xe
|
||||
|
||||
function install_ubuntu() {
|
||||
. /etc/os-release
|
||||
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
echo "Ubuntu version ${VERSION_CODENAME} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -35,25 +35,24 @@ function install_ubuntu() {
|
||||
# The xpu-smi packages
|
||||
apt-get install -y flex bison xpu-smi
|
||||
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Compute and Media Runtimes
|
||||
# Compute and Media Runtimes
|
||||
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
|
||||
apt-get install -y \
|
||||
intel-opencl-icd intel-level-zero-gpu level-zero \
|
||||
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
|
||||
else # rolling driver
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
else # jammy
|
||||
apt-get install -y \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
fi
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
|
||||
# Install Intel Support Packages
|
||||
apt-get install -y ${XPU_PACKAGES}
|
||||
@ -66,7 +65,7 @@ function install_ubuntu() {
|
||||
function install_rhel() {
|
||||
. /etc/os-release
|
||||
if [[ "${ID}" == "rhel" ]]; then
|
||||
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
echo "RHEL version ${VERSION_ID} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -147,7 +146,7 @@ function install_sles() {
|
||||
XPU_DRIVER_VERSION=""
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Use GPU driver LTS releases
|
||||
XPU_DRIVER_VERSION="/lts/2350"
|
||||
XPU_DRIVER_VERSION="/lts/2523"
|
||||
fi
|
||||
|
||||
# Default use Intel® oneAPI Deep Learning Essentials 2025.1
|
||||
|
||||
@ -49,11 +49,7 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
fi
|
||||
BASE_TARGET=rocm
|
||||
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -50,6 +50,10 @@ RUN rm install_ninja.sh
|
||||
ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
|
||||
ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH
|
||||
|
||||
# Build a newer version of libgomp than the one shipped with AlmaLinux 8.
|
||||
COPY ./common/install_libgomp.sh install_libgomp.sh
|
||||
RUN bash ./install_libgomp.sh && rm install_libgomp.sh
|
||||
|
||||
# git236+ would refuse to run git commands in repos owned by other users
|
||||
# Which causes version check to fail, as pytorch repo is bind-mounted into the image
|
||||
# Override this behaviour by treating every folder as safe
|
||||
|
||||
@ -87,11 +87,7 @@ case ${image} in
|
||||
MANY_LINUX_VERSION="2_28"
|
||||
DEVTOOLSET_VERSION="11"
|
||||
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
|
||||
;;
|
||||
manylinux2_28-builder:xpu)
|
||||
|
||||
@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
|
||||
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
|
||||
RUN rm install_halide.sh common_utils.sh halide.txt
|
||||
|
||||
ARG PALLAS
|
||||
ARG CUDA_VERSION
|
||||
# Install JAX with CUDA support (for Pallas)
|
||||
COPY ./common/install_jax.sh install_jax.sh
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
|
||||
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
|
||||
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
|
||||
|
||||
ARG ONNX
|
||||
# Install ONNX dependencies
|
||||
COPY ./common/install_onnx.sh ./common/common_utils.sh ./
|
||||
|
||||
@ -8,9 +8,11 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
try:
|
||||
from typing import Any, Callable, Required, TypedDict # Python 3.11+
|
||||
from collections.abc import Callable # Python 3.11+
|
||||
from typing import Any, Required, TypedDict
|
||||
except ImportError:
|
||||
from typing import Any, Callable, TypedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import Required # Fallback for Python <3.11
|
||||
|
||||
|
||||
@ -168,14 +168,16 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/compiler/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/umf/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Enable XCCL build
|
||||
export USE_XCCL=1
|
||||
export USE_MPI=0
|
||||
# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
|
||||
export USE_KINETO=0
|
||||
export TORCH_XPU_ARCH_LIST=pvc
|
||||
fi
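A minimal sketch of how one could confirm the sourced oneAPI environment is actually visible at this point (assumes the standard oneAPI layout where the DPC++ compiler `icpx` lands on PATH after sourcing the compiler vars.sh):

# Sketch: sanity-check the XPU build environment before compiling
icpx --version
env | grep -E 'USE_XCCL|USE_KINETO|TORCH_XPU_ARCH_LIST'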
|
||||
|
||||
|
||||
@ -208,6 +208,8 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
|
||||
source /opt/intel/oneapi/ccl/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/mpi/latest/env/vars.sh
|
||||
# shellcheck disable=SC1091
|
||||
source /opt/intel/oneapi/pti/latest/env/vars.sh
|
||||
# Check XPU status before testing
|
||||
timeout 30 xpu-smi discovery || true
|
||||
fi
|
||||
@ -337,7 +339,7 @@ test_python() {
|
||||
|
||||
test_python_smoke() {
|
||||
# Smoke tests for H100/B200
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
@ -824,6 +826,11 @@ test_inductor_halide() {
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_pallas() {
|
||||
python test/run_test.py --include inductor/test_pallas.py --verbose
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_triton_cpu() {
|
||||
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
|
||||
assert_git_not_dirty
|
||||
@ -1724,6 +1731,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
test_inductor_halide
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
|
||||
test_inductor_pallas
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
|
||||
test_inductor_triton_cpu
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
||||
|
||||
@ -70,7 +70,7 @@ sccache --zero-stats
|
||||
sccache --show-stats
|
||||
|
||||
# Build the wheel
|
||||
python -m build --wheel --no-build-isolation
|
||||
python -m build --wheel --no-isolation
|
||||
if ($LASTEXITCODE -ne 0) { exit 1 }
|
||||
|
||||
# Install the wheel locally
|
||||
|
||||
.github/ci_commit_pins/vision.txt (vendored, 2 lines changed)
@ -1 +1 @@
|
||||
cfbc5c2f1c798991715a6b06bb3ce46478c4487c
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
|
||||
.github/labeler.yml (vendored, 9 lines changed)
@ -138,7 +138,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -148,7 +149,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -158,7 +160,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
- third_party/fbgemm
|
||||
|
||||
.github/nitpicks.yml (vendored, 1 line changed)
@ -10,3 +10,4 @@
|
||||
pathFilter:
|
||||
- 'torch/csrc/inductor/aoti_torch/c/*'
|
||||
- 'torch/csrc/inductor/aoti_torch/generated/*'
|
||||
- 'torch/csrc/stable/c/*'
|
||||
|
||||
.github/pytorch-probot.yml (vendored, 6 lines changed)
@ -2,8 +2,8 @@ tracking_issue: 24422
|
||||
ciflow_tracking_issue: 64124
|
||||
ciflow_push_tags:
|
||||
- ciflow/b200
|
||||
- ciflow/b200-symm-mem
|
||||
- ciflow/b200-distributed
|
||||
- ciflow/b200-symm-mem
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
@ -22,6 +22,8 @@ ciflow_push_tags:
|
||||
- ciflow/inductor-perf-test-nightly-xpu
|
||||
- ciflow/inductor-periodic
|
||||
- ciflow/inductor-rocm
|
||||
- ciflow/inductor-rocm-mi200
|
||||
- ciflow/inductor-rocm-mi300
|
||||
- ciflow/linux-aarch64
|
||||
- ciflow/mps
|
||||
- ciflow/nightly
|
||||
@ -33,11 +35,13 @@ ciflow_push_tags:
|
||||
- ciflow/quantization-periodic
|
||||
- ciflow/riscv64
|
||||
- ciflow/rocm
|
||||
- ciflow/rocm-mi200
|
||||
- ciflow/rocm-mi300
|
||||
- ciflow/rocm-mi355
|
||||
- ciflow/rocm-navi31
|
||||
- ciflow/s390
|
||||
- ciflow/slow
|
||||
- ciflow/slow-rocm-mi200
|
||||
- ciflow/torchbench
|
||||
- ciflow/triton_binaries
|
||||
- ciflow/trunk
|
||||
|
||||
.github/scripts/delete_old_branches.py (vendored, 3 lines changed)
@ -1,10 +1,11 @@
|
||||
# Delete old branches
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from github_utils import gh_fetch_json_dict, gh_graphql
|
||||
from gitutils import GitRepo
|
||||
|
||||
.github/scripts/filter_test_configs.py (vendored, 3 lines changed)
@ -8,10 +8,11 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
from logging import info
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import yaml
|
||||
|
||||
.github/scripts/get_workflow_job_id.py (vendored, 3 lines changed)
@ -11,7 +11,8 @@ import sys
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
|
||||
.github/scripts/github_utils.py (vendored, 3 lines changed)
@ -3,8 +3,9 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from typing import Any, cast, Optional, Union
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
.github/scripts/gitutils.py (vendored, 4 lines changed)
@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
.github/scripts/trymerge.py (vendored, 4 lines changed)
@ -17,12 +17,12 @@ import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
|
||||
.github/workflows/docker-builds.yml (vendored, 7 lines changed)
@ -56,6 +56,8 @@ jobs:
|
||||
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
|
||||
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
|
||||
pytorch-linux-jammy-py3.10-clang12,
|
||||
pytorch-linux-jammy-py3.11-clang12,
|
||||
pytorch-linux-jammy-py3.12-clang12,
|
||||
pytorch-linux-jammy-py3.13-clang12,
|
||||
pytorch-linux-jammy-py3.14-clang12,
|
||||
pytorch-linux-jammy-rocm-n-py3,
|
||||
@ -65,9 +67,10 @@ jobs:
|
||||
pytorch-linux-jammy-py3.10-gcc11,
|
||||
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3.12-halide,
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
|
||||
pytorch-linux-jammy-xpu-n-1-py3,
|
||||
pytorch-linux-jammy-xpu-n-py3,
|
||||
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-noble-xpu-n-py3,
|
||||
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3-clang18-asan,
|
||||
pytorch-linux-jammy-py3-clang12-onnx,
|
||||
pytorch-linux-jammy-linter,
|
||||
|
||||
@ -83,8 +83,8 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -117,7 +117,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
|
||||
@ -7,7 +7,7 @@ on:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
- ciflow/inductor-rocm-mi200/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
.github/workflows/inductor-rocm-mi300.yml (vendored, 1 line changed)
@ -7,6 +7,7 @@ on:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
- ciflow/inductor-rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
||||
.github/workflows/inductor-unittest.yml (vendored, 26 lines changed)
@ -81,6 +81,32 @@ jobs:
|
||||
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-build:
|
||||
name: inductor-pallas-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
|
||||
cuda-arch-list: '8.9'
|
||||
runner: linux.8xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-test:
|
||||
name: inductor-pallas-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: inductor-pallas-build
|
||||
with:
|
||||
build-environment: linux-jammy-py3.12-gcc11
|
||||
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-triton-cpu-build:
|
||||
name: inductor-triton-cpu-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
.github/workflows/periodic-rocm-mi200.yml (vendored, 1 line changed)
@ -11,7 +11,6 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi200/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
.github/workflows/periodic-rocm-mi300.yml (vendored, 1 line changed)
@ -11,6 +11,7 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi300/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
.github/workflows/pull.yml (vendored, 8 lines changed)
@ -342,16 +342,16 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
||||
# sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
|
||||
|
||||
@ -5,7 +5,7 @@ on:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
- ciflow/rocm-mi200/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
.github/workflows/rocm-mi300.yml (vendored, 1 line changed)
@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
- ciflow/rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
|
||||
.github/workflows/slow-rocm-mi200.yml (vendored, new file, 81 lines)
@ -0,0 +1,81 @@
|
||||
# This workflow is dedicated to hosting slow jobs that are run only periodically because
|
||||
# they are too slow to run in every commit. The list of slow tests can be found in
|
||||
# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
|
||||
name: slow-rocm-mi200
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/slow/*
|
||||
- ciflow/slow-rocm-mi200/*
|
||||
schedule:
|
||||
- cron: 0 */3 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
llm-td:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: before-test
|
||||
uses: ./.github/workflows/llm_td_retrieval.yml
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
target-determination:
|
||||
name: before-test
|
||||
uses: ./.github/workflows/target_determination.yml
|
||||
needs: llm-td
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
.github/workflows/slow.yml (vendored, 30 lines changed)
@ -105,36 +105,6 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3_10-clang18-asan-build:
|
||||
name: linux-jammy-py3.10-clang18-asan
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
.github/workflows/upload-test-stats.yml (vendored, 5 lines changed)
@ -11,15 +11,16 @@ on:
|
||||
- inductor
|
||||
- unstable
|
||||
- slow
|
||||
- slow-rocm-mi200
|
||||
- unstable-periodic
|
||||
- inductor-periodic
|
||||
- rocm
|
||||
- rocm-mi200
|
||||
- rocm-mi300
|
||||
- rocm-mi355
|
||||
- inductor-micro-benchmark
|
||||
- inductor-micro-benchmark-x86
|
||||
- inductor-cu124
|
||||
- inductor-rocm
|
||||
- inductor-rocm-mi200
|
||||
- inductor-rocm-mi300
|
||||
- mac-mps
|
||||
- linux-aarch64
|
||||
|
||||
.github/workflows/xpu.yml (vendored, 20 lines changed)
@ -47,15 +47,15 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -74,17 +74,17 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-test:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-test:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: linux-jammy-xpu-n-py3_10-build
|
||||
needs: linux-noble-xpu-n-py3_10-build
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
windows-xpu-n-1-build:
|
||||
|
||||
.gitignore (vendored, 1 line changed)
@ -127,6 +127,7 @@ torch/test/
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
|
||||
torch/version.py
|
||||
torch/_inductor/kernel/vendored_templates/*
|
||||
minifier_launcher.py
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
|
||||
|
||||
@ -143,7 +143,8 @@ init_command = [
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
|
||||
'numpy==2.3.4 ; python_version >= "3.14"',
|
||||
'expecttest==0.3.0',
|
||||
'pyrefly==0.36.2',
|
||||
'sympy==1.13.3',
|
||||
@ -1401,7 +1402,7 @@ init_command = [
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'usort==1.0.8.post1',
|
||||
'isort==6.0.1',
|
||||
'ruff==0.13.1', # sync with RUFF
|
||||
'ruff==0.14.4', # sync with RUFF
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
@ -1536,7 +1537,7 @@ init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.13.1', # sync with PYFMT
|
||||
'ruff==0.14.4', # sync with PYFMT
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
||||
@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
||||
/test/inductor/test_flex_attention.py @drisspg
|
||||
/test/inductor/test_flex_decoding.py @drisspg
|
||||
|
||||
# Low Precision GEMMs
|
||||
# Low Precision & Grouped GEMMs
|
||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
|
||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
|
||||
|
||||
@ -174,6 +174,12 @@ class TORCH_API Context {
|
||||
static long versionCuDNN() {
|
||||
return detail::getCUDAHooks().versionCuDNN();
|
||||
}
|
||||
static long versionRuntimeCuDNN() {
|
||||
return detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
}
|
||||
static long versionCuDNNFrontend() {
|
||||
return detail::getCUDAHooks().versionCuDNNFrontend();
|
||||
}
|
||||
static bool hasCuSOLVER() {
|
||||
return detail::getCUDAHooks().hasCuSOLVER();
|
||||
}
|
||||
|
||||
@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
||||
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
||||
}
|
||||
|
||||
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
||||
c10::DeviceIndex device_index) {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
||||
}
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include <cuda.h> // For CUDA_VERSION
|
||||
@ -61,12 +62,9 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
case enum_type: { \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
||||
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
||||
@ -95,14 +93,6 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// The AT_DISPATCH_* family of macros provides the ability to
|
||||
// conveniently generate specializations of a kernel over all of the
|
||||
// dtypes we care about in PyTorch. We call it "dispatch" because
|
||||
@ -190,27 +180,13 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
// but we're just being safe (and it doesn't hurt.) Note we must
|
||||
// use it to shut up warnings about unused store.
|
||||
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const auto& the_type = TYPE; \
|
||||
constexpr const char* at_dispatch_name = NAME; \
|
||||
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
||||
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED( \
|
||||
false, \
|
||||
'"', \
|
||||
at_dispatch_name, \
|
||||
"\" not implemented for '", \
|
||||
toString(_st), \
|
||||
"'"); \
|
||||
} \
|
||||
C10_DIAGNOSTIC_POP() \
|
||||
}()
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
THO_DISPATCH_SWITCH_TMPL( \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE, \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
|
||||
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
// This is a new implementation of the AT_DISPATCH macro family from
|
||||
@ -74,41 +79,19 @@
|
||||
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
||||
// relied on GPT4 to help me get it right.
|
||||
|
||||
// Public API macros
|
||||
|
||||
// See documentation above
|
||||
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
||||
|
||||
// This macro lets you pass an arbitrary expression that may contain internal
|
||||
// commas to another macro without having the commas causing the expression
|
||||
// to be interpreted as being multiple arguments
|
||||
#define AT_WRAP(...) __VA_ARGS__
|
||||
|
||||
#define AT_FLOAT8_TYPES \
|
||||
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
|
||||
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
|
||||
|
||||
#define AT_INTEGRAL_TYPES \
|
||||
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
|
||||
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
|
||||
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
|
||||
#define AT_INTEGRAL_TYPES_V2 \
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
|
||||
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
|
||||
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
|
||||
// NB: not *actually* all types
|
||||
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
||||
#define AT_ALL_TYPES_AND_COMPLEX \
|
||||
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
|
||||
|
||||
// Helper macros
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
AT_DISPATCH_SWITCH, \
|
||||
AT_DISPATCH_CASE, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_WRAP(BODY), \
|
||||
__VA_ARGS__)
|
||||
|
||||
// Unused helper macros, kept for BC:
|
||||
#define AT_AP_VAR(N, T, ...) \
|
||||
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
||||
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
|
||||
#define AT_CONCAT_AUX(a, b) a##b
|
||||
#define AT_EXPAND(X) X
|
||||
|
||||
// Ensure we never have too many scalar types for the expansion here to
|
||||
// support. To bump this, you must regenerate the macros below.
|
||||
@ -119,12 +102,6 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
|
||||
|
||||
num_args = 60
|
||||
|
||||
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
|
||||
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
|
||||
|
||||
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
|
||||
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
|
||||
|
||||
for i in range(1, num_args+1):
|
||||
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
||||
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
||||
@ -135,8 +112,6 @@ for i in range(1, num_args+1):
|
||||
// Begin generated code
|
||||
// clang-format off
|
||||
|
||||
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
||||
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
|
||||
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
||||
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
||||
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
||||
|
||||
@ -226,8 +226,8 @@ template <
|
||||
typename B = HostBlock<S>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() {
|
||||
active_ = false;
|
||||
if (pinned_use_background_threads()) {
|
||||
if (active_) {
|
||||
active_ = false;
|
||||
getBackgroundThreadPool()->waitWorkComplete();
|
||||
}
|
||||
}
|
||||
@ -260,6 +260,7 @@ struct CachingHostAllocatorImpl {
|
||||
if (pinned_use_background_threads()) {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
active_ = true;
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
while (active_) {
|
||||
process_events();
|
||||
@ -683,9 +684,9 @@ struct CachingHostAllocatorImpl {
|
||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
|
||||
// Indicates whether the object is active.
|
||||
// Indicates whether the event-processing thread pool is active.
|
||||
// Set to false in the destructor to signal background threads to stop.
|
||||
std::atomic<bool> active_{true};
|
||||
std::atomic<bool> active_{false};
|
||||
protected:
|
||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
||||
};
|
||||
|
||||
@ -1597,7 +1597,7 @@ bool gemm_and_bias(
|
||||
}
|
||||
|
||||
using opmath_t = at::opmath_type<Dtype>;
|
||||
opmath_t beta_val = 0; // bias is added in epilogue
|
||||
opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
|
||||
|
||||
cudaDataType_t abType = CUDA_R_32F;
|
||||
cudaDataType_t cType = CUDA_R_32F;
|
||||
@ -1686,15 +1686,22 @@ bool gemm_and_bias(
|
||||
_syncCurrentWithCarveoutStream(stream, true);
|
||||
}
|
||||
#endif
|
||||
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
|
||||
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
|
||||
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
|
||||
} else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
|
||||
epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
|
||||
}
|
||||
const auto epilogue = [&]() -> cublasLtEpilogue_t {
|
||||
// The cuBLAS documentation indicates that
|
||||
// *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
|
||||
// but we keep it verbose here for clarity.
|
||||
switch (activation) {
|
||||
case GEMMAndBiasActivationEpilogue::RELU:
|
||||
return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
|
||||
case GEMMAndBiasActivationEpilogue::GELU:
|
||||
return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
|
||||
default:
|
||||
return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
|
||||
}
|
||||
}();
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
|
||||
|
||||
if (bias != nullptr) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
|
||||
if (bias) {
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#include <ATen/cudnn/cudnn-wrapper.h>
|
||||
#include <cudnn_frontend.h>
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
@ -351,6 +352,26 @@ long CUDAHooks::versionCuDNN() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionRuntimeCuDNN() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#ifndef USE_STATIC_CUDNN
|
||||
return cudnnGetVersion();
|
||||
#else
|
||||
return CUDNN_VERSION;
|
||||
#endif
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionCuDNNFrontend() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
return CUDNN_FRONTEND_VERSION;
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionMIOpen() const {
|
||||
#if AT_ROCM_ENABLED()
|
||||
return MIOPEN_VERSION_MAJOR * 10000 +
|
||||
|
||||
@ -49,6 +49,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
bool hasCUDART() const override;
|
||||
long versionCUDART() const override;
|
||||
long versionCuDNN() const override;
|
||||
long versionRuntimeCuDNN() const override;
|
||||
long versionCuDNNFrontend() const override;
|
||||
long versionMIOpen() const override;
|
||||
std::string showConfig() const override;
|
||||
double batchnormMinEpsilonCuDNN() const override;
|
||||
|
||||
@ -174,6 +174,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionRuntimeCuDNN() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionCuDNNFrontend() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionMIOpen() const {
|
||||
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
||||
DispatchKey::Negative,
|
||||
DispatchKey::Conjugate,
|
||||
DispatchKey::XLA,
|
||||
DispatchKey::XPU,
|
||||
DispatchKey::HPU,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
|
||||
@ -409,7 +409,7 @@ struct ConvParams {
|
||||
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
|
||||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
@ -453,7 +453,7 @@ struct ConvParams {
|
||||
}
|
||||
// native kernel doesn't support 64-bit non-splittable case
|
||||
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
|
||||
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
|
||||
if (cudnn_version < 0 || cudnn_version > 91000) {
|
||||
|
||||
@ -147,14 +147,24 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
|
||||
/*
|
||||
* Check whether for the given input we want to enable the Lt interface
|
||||
*/
|
||||
static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
|
||||
static bool isInputCompliesAddmmCudaLt(
|
||||
Tensor& result,
|
||||
const Tensor& self,
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
Activation activation
|
||||
) {
|
||||
#ifdef USE_ROCM
|
||||
// Implies a 2D bias, which we currently do not send through Lt.
|
||||
// TODO: this check is done pre col-major input preparation,
|
||||
// so this condition can be relaxed in cases when a col-major
|
||||
// copy of result is needed.
|
||||
if (result.is_same(self)) {
|
||||
if (self.is_same(result) || self.dim() == 2) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION == 60400
|
||||
// hipblaslt TT fp32 regression on ROCm 6.4, cannot use
|
||||
@ -169,13 +179,33 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const
|
||||
#if defined(CUDA_VERSION) || defined(USE_ROCM)
|
||||
const auto scalar_type = mat1.scalar_type();
|
||||
return (beta.toComplexDouble() == 1.0
|
||||
// NOTE: row-major result is important when bias is 1D.
|
||||
// This is because Lt broadcasts 1D bias over the columns
|
||||
// while the aten::addmm API broadcasts it over the rows,
|
||||
// and this is in conjunction with the data preparation
|
||||
// procedure that does not transpose arguments with
|
||||
// col-major result. For col-major result we need
|
||||
// to explicitly transpose the problem so that bias is
|
||||
// correctly applied.
|
||||
// TODO: enable col-major result if needed.
|
||||
// TODO: no need to check result's layout when
|
||||
// !result.is_same(self) and self.dim() == 2, because
|
||||
// self needs to be copied into result and the bias ptr
|
||||
// will be ignored.
|
||||
&& result.dim() == 2 && result.is_contiguous()
|
||||
// Conditions for bias to be fusable
|
||||
&& (
|
||||
self.is_contiguous() &&
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
( // Conditions for bias to be fusable -- implies direct Lt path without copies.
|
||||
self.is_contiguous() &&
|
||||
// NOTE: fine to have 1-len dims to the left from the right-most one
|
||||
(self.dim() == 1 || self.squeeze().dim() == 1) &&
|
||||
self.sizes().back() == mat2_sizes[1]
|
||||
)
|
||||
|| ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
|
||||
// and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
|
||||
// See also TODO from above.
|
||||
activation != Activation::None && // Lt is faster when activation is fused
|
||||
(self.dim() == 2 && at::is_expandable_to(self.sizes(), {mat1_sizes[0], mat2_sizes[1]}))
|
||||
)
|
||||
)
|
||||
&& ( // some dtype restrictions
|
||||
#ifndef USE_ROCM
|
||||
@ -270,7 +300,16 @@ bool launchGemmAndBiasCublasLt(
|
||||
const Scalar& alpha,
|
||||
Activation activation = Activation::None
|
||||
) {
|
||||
const auto* self_ptr = self.const_data_ptr<scalar_t>();
|
||||
// We apply bias in the epilogue only when it is 1D,
|
||||
// or when it can be squeezed to 1D.
|
||||
// self_ptr == nullptr implies ignore bias epilogue
|
||||
// and use standard gemm-like API.
|
||||
const auto* self_ptr = [&]() -> auto {
|
||||
if (self.dim() == 1 || self.squeeze().dim() == 1) {
|
||||
return self.const_data_ptr<scalar_t>();
|
||||
}
|
||||
return static_cast<const scalar_t*>(nullptr);
|
||||
}();
|
||||
|
||||
const auto tuning_ctx = at::cuda::tunable::getTuningContext();
|
||||
if (tuning_ctx->IsTunableOpEnabled()) {
|
||||
@ -356,7 +395,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
|
||||
#endif
|
||||
// Condition on the input
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
|
||||
disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
|
||||
// }
|
||||
|
||||
at::ScalarType scalar_type = mat1.scalar_type();
|
||||
@ -366,19 +405,20 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
if (!result.is_same(self)) {
|
||||
at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
|
||||
|
||||
// We use bias ptr in the Lt path only when bias is 1D
|
||||
const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
|
||||
const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
|
||||
if (disable_addmm_cuda_lt) {
|
||||
// When in non-Lt path we do expand self even before
|
||||
if (!use_bias_ptr_lt) {
|
||||
// We do expand self even before
|
||||
// check for beta != 0.0 to make sure that
|
||||
// test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
|
||||
// runs green.
|
||||
return expand_size(self, result.sizes(), "addmm");
|
||||
}
|
||||
// copy next, should broadcast
|
||||
return c10::MaybeOwned<Tensor>::borrowed(self);
|
||||
}();
|
||||
// We copy bias when in the non-Lt path
|
||||
if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
|
||||
// We do not copy bias only when we need the bias ptr
|
||||
if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
|
||||
// NOTE: self should broadcast over result
|
||||
at::native::copy_(result, *self_maybe_expanded);
|
||||
}
|
||||
|
||||
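The `use_bias_ptr_lt` logic above is internal to the CUDA implementation; from the Python side, the case it targets is an ordinary `addmm` call with a 1D bias. A minimal illustrative sketch follows (shapes and dtype are arbitrary, and whether the Lt bias epilogue is actually taken depends on the heuristics above):

```python
import torch

# torch.addmm(bias, m1, m2) == beta * bias + alpha * (m1 @ m2); a 1D bias is
# broadcast over the rows of the result, which is exactly the shape the
# cublasLt bias-epilogue path above is deciding about.
m1 = torch.randn(8, 16, device="cuda", dtype=torch.float16)
m2 = torch.randn(16, 32, device="cuda", dtype=torch.float16)
bias = torch.randn(32, device="cuda", dtype=torch.float16)

out = torch.addmm(bias, m1, m2)
assert torch.allclose(out, bias + m1 @ m2, atol=1e-2)
```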
@ -884,6 +884,69 @@ struct type_specialized_kernel_launcher {
|
||||
}
|
||||
};
|
||||
|
||||
template <int arg_index>
|
||||
struct type_specialized_broadcast_kernel_launcher {
|
||||
template <
|
||||
typename func_t,
|
||||
typename array_t,
|
||||
typename dtypes_t,
|
||||
typename calc_t>
|
||||
static void apply(
|
||||
int64_t numel,
|
||||
func_t f,
|
||||
array_t data,
|
||||
dtypes_t dtypes,
|
||||
calc_t offset_calc) {
|
||||
using traits = function_traits<func_t>;
|
||||
using ret_t = typename traits::result_type;
|
||||
using arg0_t = typename traits::template arg<0>::type;
|
||||
using arg1_t = typename traits::template arg<1>::type;
|
||||
if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
|
||||
dtypes[1] == rt_binary_specializations[arg_index][1] &&
|
||||
dtypes[2] == rt_binary_specializations[arg_index][2]) {
|
||||
using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
|
||||
using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
|
||||
using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
auto offsets0 = offset_calc.get(idx);
|
||||
auto offsets1 = offset_calc.get(idx + grp_sz);
|
||||
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
|
||||
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
|
||||
void* out0 = data[0] + offsets0[0];
|
||||
void* out1 = data[0] + offsets1[0];
|
||||
void* out2 = data[0] + offsets2[0];
|
||||
void* out3 = data[0] + offsets3[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
|
||||
ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
|
||||
auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
|
||||
ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
|
||||
auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
|
||||
auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
|
||||
ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
|
||||
auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
|
||||
auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
|
||||
ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
|
||||
*(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
|
||||
*(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
|
||||
*(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
|
||||
*(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
|
||||
} else {
|
||||
auto offsets = offset_calc.get(idx);
|
||||
void* out = data[0] + offsets[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
|
||||
ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
*(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
@ -1002,6 +1065,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
|
||||
}
|
||||
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
|
||||
#ifdef USE_ROCM
|
||||
if (check_binary_rt_types_for_specialization(iter)) {
|
||||
// constexpr to reduce the number of kernels generated for
|
||||
// broadcast elementwise with mixed dtypes and limit which functors are actually
|
||||
// applied to the load and store at compile time.
|
||||
using func_tuple = typename traits::ArgsTuple;
|
||||
if constexpr (
|
||||
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
|
||||
check_binary_functor_types_for_specialization<
|
||||
func_tuple,
|
||||
float,
|
||||
float,
|
||||
traits::arity,
|
||||
/*arg_num=*/0>::check()) {
|
||||
memory::detail::static_unroll<
|
||||
type_specialized_broadcast_kernel_launcher,
|
||||
rt_binary_specializations.size()>::with_args(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
dtypes,
|
||||
offset_calc
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
|
||||
@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
|
||||
// supported in conv.
|
||||
mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
|
||||
if (groups > 1 && weight_zero_points.numel() > 1)
|
||||
mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
|
||||
mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
bool src_need_zp = (act_zero_point != 0);
|
||||
|
||||
@ -141,6 +141,9 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
};
|
||||
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
}
|
||||
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
@ -2803,7 +2803,7 @@
|
||||
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: floor_divide_out
|
||||
CPU, CUDA, MPS, MTIA: floor_divide_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
|
||||
|
||||
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
@ -4292,6 +4292,7 @@
|
||||
dispatch:
|
||||
SparseCPU: sparse_sparse_matmul_cpu
|
||||
SparseCUDA: sparse_sparse_matmul_cuda
|
||||
SparseMPS: sparse_sparse_matmul_mps
|
||||
autogen: _sparse_sparse_matmul.out
|
||||
|
||||
- func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
|
||||
@ -4383,7 +4384,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: mv
|
||||
SparseCPU, SparseCUDA: mv_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
|
||||
|
||||
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
@ -9832,7 +9833,7 @@
|
||||
structured_delegate: erfinv.out
|
||||
variants: method, function
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: erfinv_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr
|
||||
tags: pointwise
|
||||
|
||||
@ -9841,7 +9842,7 @@
|
||||
structured_delegate: erfinv.out
|
||||
variants: method
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: erfinv_sparse_
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_
|
||||
tags: pointwise
|
||||
|
||||
@ -9851,7 +9852,7 @@
|
||||
structured_inherits: TensorIteratorBase
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: erfinv_out
|
||||
SparseCPU, SparseCUDA: erfinv_sparse_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: erfinv_sparse_out
|
||||
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: erfinv_sparse_csr_out
|
||||
tags: pointwise
|
||||
|
||||
|
||||
@ -10,6 +10,10 @@
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_coalesce_native.h>
|
||||
#include <ATen/ops/repeat_interleave_native.h>
|
||||
#include <ATen/ops/cumsum.h>
|
||||
#include <ATen/ops/_sparse_sparse_matmul_native.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
|
||||
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
|
||||
#include <ATen/ops/cat.h>
|
||||
#include <ATen/ops/add_native.h>
|
||||
@ -888,5 +892,114 @@ static void sparse_mask_intersection_out_mps_kernel(
|
||||
/*coalesce_mask=*/false);
|
||||
}
|
||||
|
||||
Tensor sparse_sparse_matmul_mps(const Tensor& mat1_, const Tensor& mat2_) {
|
||||
TORCH_CHECK(mat1_.is_sparse() && mat2_.is_sparse(),
|
||||
"sparse_sparse_matmul_mps: both inputs must be sparse COO tensors");
|
||||
TORCH_CHECK(mat1_.is_mps() && mat2_.is_mps(),
|
||||
"sparse_sparse_matmul_mps: both inputs must be on MPS device");
|
||||
TORCH_CHECK(mat1_.dim() == 2 && mat2_.dim() == 2,
|
||||
"sparse_sparse_matmul_mps: both inputs must be 2D matrices");
|
||||
TORCH_CHECK(mat1_.dense_dim() == 0 && mat2_.dense_dim() == 0,
|
||||
"sparse_sparse_matmul_mps: only scalar values supported (dense_dim == 0)");
|
||||
TORCH_CHECK(mat1_.size(1) == mat2_.size(0),
|
||||
"mat1 and mat2 shapes cannot be multiplied (", mat1_.size(0), "x", mat1_.size(1), " and ", mat2_.size(0), "x", mat2_.size(1), ")");
|
||||
TORCH_CHECK(mat1_.scalar_type() == mat2_.scalar_type(),
|
||||
"sparse_sparse_matmul_mps: mat1 dtype ", mat1_.scalar_type(),
|
||||
" does not match mat2 dtype ", mat2_.scalar_type());
|
||||
|
||||
const auto device = mat1_.device();
|
||||
|
||||
auto A = mat1_.coalesce();
|
||||
auto B = mat2_.coalesce();
|
||||
|
||||
const auto I = A.size(0);
|
||||
const auto K = A.size(1);
|
||||
const auto N = B.size(1);
|
||||
|
||||
const auto nnzA = A._nnz();
|
||||
const auto nnzB = B._nnz();
|
||||
|
||||
// Early empty result, return an empty, coalesced tensor
|
||||
if (I == 0 || N == 0 || K == 0 || nnzA == 0 || nnzB == 0) {
|
||||
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
|
||||
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
|
||||
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
|
||||
const auto computeDtype = at::result_type(mat1_, mat2_);
|
||||
|
||||
auto A_idx = A._indices().contiguous();
|
||||
auto A_val = A._values().to(computeDtype).contiguous();
|
||||
auto A_i = A_idx.select(0, 0).contiguous();
|
||||
auto A_k = A_idx.select(0, 1).contiguous();
|
||||
|
||||
auto B_idx = B._indices().contiguous();
|
||||
auto B_val = B._values().to(computeDtype).contiguous();
|
||||
auto B_k = B_idx.select(0, 0).contiguous();
|
||||
auto B_j = B_idx.select(0, 1).contiguous();
|
||||
|
||||
// csr-style row pointers for B by k (the shared dimension)
|
||||
Tensor row_ptr_B;
|
||||
{
|
||||
auto batch_ptr = at::tensor({0LL, nnzB}, at::device(device).dtype(at::kLong));
|
||||
row_ptr_B = at::empty({K + 1}, at::device(device).dtype(at::kLong));
|
||||
build_row_ptr_per_batch_mps(B_k, batch_ptr, /*B=*/1, /*I=*/K, row_ptr_B);
|
||||
}
|
||||
|
||||
auto row_ptr_B_lo = row_ptr_B.narrow(0, 0, K);
|
||||
auto row_ptr_B_hi = row_ptr_B.narrow(0, 1, K);
|
||||
auto deg_B = row_ptr_B_hi.sub(row_ptr_B_lo);
|
||||
|
||||
auto counts = deg_B.index_select(0, A_k);
|
||||
|
||||
const int64_t P = counts.sum().item<int64_t>();
|
||||
if (P == 0) {
|
||||
auto empty_idx = at::empty({2, 0}, at::device(device).dtype(at::kLong));
|
||||
auto empty_val = at::empty({0}, at::device(device).dtype(mat1_.scalar_type()));
|
||||
auto out = _sparse_coo_tensor_unsafe(empty_idx, empty_val, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
|
||||
auto group_ids = repeat_interleave_mps(counts);
|
||||
|
||||
// exclusive cumsum of counts
|
||||
auto offsets = cumsum(counts, /*dim=*/0).sub(counts);
|
||||
auto offsets_gather = offsets.index_select(0, group_ids);
|
||||
auto within = at::arange(P, at::device(device).dtype(at::kLong)).sub(offsets_gather);
|
||||
|
||||
// Map each output element to its source B row and position
|
||||
auto k_per_out = A_k.index_select(0, group_ids);
|
||||
auto start_in_B = row_ptr_B.index_select(0, k_per_out);
|
||||
auto seg_index = start_in_B.add(within);
|
||||
|
||||
// Assemble candidate coo pairs and values
|
||||
auto i_out = A_i.index_select(0, group_ids).contiguous();
|
||||
auto j_out = B_j.index_select(0, seg_index).contiguous();
|
||||
auto vA_out = A_val.index_select(0, group_ids).contiguous();
|
||||
auto vB_out = B_val.index_select(0, seg_index).contiguous();
|
||||
auto v_out = vA_out.mul(vB_out);
|
||||
|
||||
// build (2, P) indices
|
||||
auto out_indices = at::empty({2, P}, at::device(device).dtype(at::kLong)).contiguous();
|
||||
out_indices.select(0, 0).copy_(i_out);
|
||||
out_indices.select(0, 1).copy_(j_out);
|
||||
|
||||
auto result = _sparse_coo_tensor_unsafe(
|
||||
out_indices, v_out, {I, N}, mat1_.options().dtype(computeDtype));
|
||||
|
||||
result = result.coalesce();
|
||||
|
||||
if (result.scalar_type() != mat1_.scalar_type()) {
|
||||
auto cast_vals = result._values().to(mat1_.scalar_type());
|
||||
auto out = _sparse_coo_tensor_unsafe(result._indices(), cast_vals, {I, N}, mat1_.options());
|
||||
out._coalesced_(true);
|
||||
return out;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
|
||||
} // namespace at::native
|
||||
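The new `sparse_sparse_matmul_mps` kernel above expands each nonzero of A against the matching rows of B and then coalesces duplicates. A minimal sketch of exercising it from Python, assuming an MPS-enabled build where sparse COO × sparse COO now dispatches to `SparseMPS` (the tensors and shapes here are purely illustrative):

```python
import torch

# Requires a build where torch.backends.mps.is_available() is True.
i_a = torch.tensor([[0, 1, 1], [2, 0, 2]])     # 2 x nnz indices for A
v_a = torch.tensor([1.0, 2.0, 3.0])
A = torch.sparse_coo_tensor(i_a, v_a, (2, 3), device="mps")

i_b = torch.tensor([[0, 2], [1, 0]])           # 2 x nnz indices for B
v_b = torch.tensor([4.0, 5.0])
B = torch.sparse_coo_tensor(i_b, v_b, (3, 2), device="mps")

# Sparse @ sparse returns a sparse COO result; with this change the
# computation can stay on the MPS device instead of failing over.
C = torch.sparse.mm(A, B)
print(C.to_dense())
```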
@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
|
||||
const auto s_k = params.key.sym_size(2);
|
||||
const auto d_qk = params.query.sym_size(3);
|
||||
const auto d_v = params.value.sym_size(3);
|
||||
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
|
||||
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
if (cudnn_version < 8903) {
|
||||
if (debug) {
|
||||
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
|
||||
@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
|
||||
return false;
|
||||
#endif
|
||||
#if defined(CUDNN_VERSION)
|
||||
static auto cudnn_version = cudnnGetVersion();
|
||||
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
|
||||
if (debug) {
|
||||
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support dropout in SDPA (9.11 - 9.13).");
|
||||
|
||||
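For context, the broken range gated above can also be inspected from Python via `torch.backends.cudnn.version()`; a rough user-side sketch, where the numeric encoding of the version simply mirrors the C++ check above:

```python
import torch

# Mirror of the C++ guard: cuDNN 9.11 - 9.13 cannot run SDPA with dropout.
ver = torch.backends.cudnn.version()  # e.g. 91200 for cuDNN 9.12, or None without cuDNN
dropout_ok_with_cudnn_sdpa = ver is None or not (91100 < ver < 91400)
```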
@ -52,19 +52,18 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
|
||||
start.record()
|
||||
coo.matmul(mat)
|
||||
stop.record()
|
||||
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
|
||||
return coo_mean_time, csr_mean_time
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <optional>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -15,7 +17,8 @@ struct C10_API AutogradState {
|
||||
bool inference_mode,
|
||||
bool fw_grad_mode,
|
||||
bool multithreading_enabled)
|
||||
: grad_mode_(grad_mode),
|
||||
: graph_exec_group_(std::nullopt),
|
||||
grad_mode_(grad_mode),
|
||||
inference_mode_(inference_mode),
|
||||
fw_grad_mode_(fw_grad_mode),
|
||||
multithreading_enabled_(multithreading_enabled),
|
||||
@ -41,6 +44,10 @@ struct C10_API AutogradState {
|
||||
view_replay_enabled_ = view_replay_enabled;
|
||||
}
|
||||
|
||||
void set_graph_exec_group(std::optional<SafePyObject> group) {
|
||||
graph_exec_group_ = std::move(group);
|
||||
}
|
||||
|
||||
bool get_grad_mode() const {
|
||||
return grad_mode_;
|
||||
}
|
||||
@ -61,7 +68,12 @@ struct C10_API AutogradState {
|
||||
return view_replay_enabled_;
|
||||
}
|
||||
|
||||
const std::optional<SafePyObject>& get_graph_exec_group() const {
|
||||
return graph_exec_group_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<SafePyObject> graph_exec_group_;
|
||||
bool grad_mode_ : 1;
|
||||
bool inference_mode_ : 1;
|
||||
bool fw_grad_mode_ : 1;
|
||||
|
||||
@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Resets peak memory usage statistics for the specified device
|
||||
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
|
||||
c10::DeviceIndex device,
|
||||
std::shared_ptr<AllocatorState> pps) = 0;
|
||||
virtual std::string name() = 0;
|
||||
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
|
||||
c10::DeviceGuard device_guard({at::kCUDA, device});
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
|
||||
return {free, total};
|
||||
}
|
||||
};
|
||||
|
||||
// Allocator object, statically initialized
|
||||
|
||||
@ -66,6 +66,15 @@ def define_targets(rules):
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/nofatal_test",
|
||||
srcs = ["util/nofatal_test.cpp"],
|
||||
deps = [
|
||||
"//c10/util:base",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/ssize_test",
|
||||
srcs = ["util/ssize_test.cpp"],
|
||||
|
||||
c10/test/util/nofatal_test.cpp (new file, 53 lines)
@ -0,0 +1,53 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline void expectThrowsEq(T&& fn, const char* expected_msg) {
|
||||
try {
|
||||
std::forward<T>(fn)();
|
||||
} catch (const c10::Error& e) {
|
||||
EXPECT_TRUE(
|
||||
std::string(e.what_without_backtrace()).find(expected_msg) !=
|
||||
std::string::npos);
|
||||
return;
|
||||
}
|
||||
ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg
|
||||
<< "\" but didn't throw";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(NofatalTest, TorchCheckComparisons) {
|
||||
// quick check that the passing case is a no-op, as expected
|
||||
TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1;
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; },
|
||||
"Check failed: 1 == 2 (1 vs. 2). i am a silly message 1");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3).");
|
||||
expectThrowsEq(
|
||||
[]() {
|
||||
void* p = nullptr;
|
||||
TORCH_CHECK_NOTNULL(p);
|
||||
},
|
||||
"Check failed: 'p' must be non NULL.");
|
||||
|
||||
#if GTEST_HAS_DEATH_TEST
|
||||
#ifndef NDEBUG
|
||||
// in debug builds, DCHECK should result in death
|
||||
EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed");
|
||||
#else
|
||||
TORCH_DCHECK_EQ(1, 2); // no-op
|
||||
#endif
|
||||
#endif // GTEST_HAS_DEATH_TEST
|
||||
}
|
||||
@ -702,6 +702,98 @@ namespace c10::detail {
|
||||
#define TORCH_CHECK_ARG(cond, argN, ...) \
|
||||
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
|
||||
|
||||
#ifndef FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \
|
||||
.stream()
|
||||
#else
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef NON_FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger( \
|
||||
__FILE__, __LINE__, ::google::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#else
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Binary comparison check macros
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
NON_FATAL_IF(((val1)op(val2))) \
|
||||
<< "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
|
||||
<< (val2) << "). "
|
||||
|
||||
#define TORCH_DCHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << "). "
|
||||
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
// Debug versions of TORCH_CHECK_OP macros
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// Optimized versions - generate no code
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Null pointer check macro
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false)
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true)
|
||||
#else // !NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Deprecated macros
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
@ -291,6 +291,32 @@ namespace c10 {
|
||||
using fLB::FLAGS_logtostderr;
|
||||
using fLI::FLAGS_minloglevel;
|
||||
using fLI::FLAGS_v;
|
||||
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {}
|
||||
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
if (severity_ == ::google::GLOG_FATAL) {
|
||||
DealWithFatal();
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
LOG(FATAL) << stream_.str();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DEFINE_int(
|
||||
@ -412,17 +438,16 @@ void ShowLogInfoToStderr() {
|
||||
FLAGS_caffe2_log_level = GLOG_INFO;
|
||||
}
|
||||
|
||||
MessageLogger::MessageLogger(const char* file, int line, int severity)
|
||||
: severity_(severity) {
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: severity_(severity), exit_on_fatal_(exit_on_fatal) {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
}
|
||||
#ifdef ANDROID
|
||||
tag_ = "native";
|
||||
#else // !ANDROID
|
||||
tag_ = "";
|
||||
#endif // ANDROID
|
||||
|
||||
time_t rawtime = 0;
|
||||
time(&rawtime);
|
||||
@ -458,7 +483,7 @@ MessageLogger::MessageLogger(const char* file, int line, int severity)
|
||||
}
|
||||
|
||||
// Output the contents of the stream to the proper channel on destruction.
|
||||
MessageLogger::~MessageLogger() {
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
@ -498,6 +523,18 @@ MessageLogger::~MessageLogger() {
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
abort();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // !C10_USE_GLOG
|
||||
|
||||
c10/util/logging_common.h (new file, 74 lines)
@ -0,0 +1,74 @@
|
||||
#ifndef C10_UTIL_LOGGING_COMMON_H_
|
||||
#define C10_UTIL_LOGGING_COMMON_H_
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// MessageLogger that throws exceptions instead of aborting (glog version)
|
||||
// or logs and may abort (non-glog version).
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal = true);
|
||||
~MessageLogger() noexcept(false);
|
||||
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream();
|
||||
|
||||
private:
|
||||
// When there is a fatal log and exit_on_fatal_ == true, we abort;
|
||||
// otherwise, we throw.
|
||||
void DealWithFatal();
|
||||
|
||||
#if defined(ANDROID) && !defined(C10_USE_GLOG)
|
||||
const char* tag_{"native"};
|
||||
#endif
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
bool exit_on_fatal_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Forward declarations for CheckNotNull functions
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // C10_UTIL_LOGGING_COMMON_H_
|
||||
@ -47,57 +47,53 @@ INSTANTIATE_FOR_CONTAINER(set)
|
||||
|
||||
#endif
|
||||
|
||||
#include <c10/util/logging_common.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
// Additional macros on top of glog
|
||||
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
|
||||
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
|
||||
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
|
||||
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
|
||||
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
|
||||
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
|
||||
namespace c10 {
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GT(val1, val2)
|
||||
#endif // NDEBUG
|
||||
[[noreturn]] void ThrowEnforceNotMet(
|
||||
const char* file,
|
||||
const int line,
|
||||
const char* condition,
|
||||
const std::string& msg,
|
||||
const void* caller);
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
if (t == nullptr) {
|
||||
MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
DCHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
// Log with source location information override (to be used in generic
|
||||
// warning/error handlers implemented as functions, not macros)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Flags.h>
|
||||
#include <c10/util/logging_common.h>
|
||||
|
||||
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
|
||||
|
||||
@ -24,61 +25,40 @@ const int GLOG_ERROR = 2;
|
||||
const int GLOG_WARNING = 1;
|
||||
const int GLOG_INFO = 0;
|
||||
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(const char* file, int line, int severity);
|
||||
~MessageLogger();
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
// When there is a fatal log, we simply abort.
|
||||
void DealWithFatal() {
|
||||
abort();
|
||||
}
|
||||
|
||||
const char* tag_;
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Log a message and terminate.
|
||||
template <class T>
|
||||
void LogMessageFatal(const char* file, int line, const T& message) {
|
||||
MessageLogger(file, line, GLOG_FATAL).stream() << message;
|
||||
}
|
||||
|
||||
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
|
||||
// pointers and smart pointers.
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
if (t == nullptr) {
|
||||
LogMessageFatal(file, line, std::string(names));
|
||||
MessageLogger(file, line, GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
@ -136,65 +116,6 @@ static_assert(
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
|
||||
#endif // NDEBUG
|
||||
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << ") "
|
||||
|
||||
// TORCH_CHECK_OP macro definitions
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only versions of TORCH_CHECK_OP macros.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ---------------------- Support for std objects --------------------------
|
||||
// These are adapted from glog to support a limited set of logging capability
|
||||
// for STL objects.
|
||||
|
||||
@ -926,15 +926,14 @@ class DeviceCachingAllocator {
|
||||
(release_cached_blocks() && alloc_block(params, true));
|
||||
}
|
||||
if (!block_found) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
const auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
const auto device_total =
|
||||
raw_device.get_info<sycl::info::device::global_mem_size>();
|
||||
// Estimate the available device memory when the SYCL runtime does not
|
||||
// support the corresponding aspect (ext_intel_free_memory).
|
||||
size_t device_free = device_prop.global_mem_size -
|
||||
size_t device_free = device_total -
|
||||
stats.reserved_bytes[static_cast<size_t>(StatType::AGGREGATE)]
|
||||
.current;
|
||||
auto& raw_device = c10::xpu::get_raw_device(device);
|
||||
// TODO: Remove the aspect check once the SYCL runtime bug is fixed on
|
||||
// affected devices.
|
||||
if (raw_device.has(sycl::aspect::ext_intel_free_memory)) {
|
||||
@ -1052,21 +1051,37 @@ class DeviceCachingAllocator {
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo() {
|
||||
const auto& device = c10::xpu::get_raw_device(device_index);
|
||||
const size_t total = device.get_info<sycl::info::device::global_mem_size>();
|
||||
TORCH_CHECK(
|
||||
device.has(sycl::aspect::ext_intel_free_memory),
|
||||
"The device (",
|
||||
device.get_info<sycl::info::device::name>(),
|
||||
") doesn't support querying the available free memory. ",
|
||||
"You can file an issue at https://github.com/pytorch/pytorch/issues ",
|
||||
"to help us prioritize its implementation.");
|
||||
const size_t free =
|
||||
device.get_info<sycl::ext::intel::info::device::free_memory>();
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
double getMemoryFraction() {
|
||||
if (!set_fraction) {
|
||||
return 1.0;
|
||||
}
|
||||
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
return static_cast<double>(allowed_memory_maximum) /
|
||||
static_cast<double>(device_prop.global_mem_size);
|
||||
static_cast<double>(device_total);
|
||||
}
|
||||
|
||||
void setMemoryFraction(double fraction) {
|
||||
c10::xpu::DeviceProp device_prop;
|
||||
c10::xpu::get_device_properties(&device_prop, device_index);
|
||||
auto device_total = device_prop.global_mem_size;
|
||||
const auto device_total =
|
||||
xpu::get_raw_device(device_index)
|
||||
.get_info<sycl::info::device::global_mem_size>();
|
||||
allowed_memory_maximum = static_cast<size_t>(fraction * device_total);
|
||||
set_fraction = true;
|
||||
}
|
||||
@ -1240,6 +1255,11 @@ class XPUAllocator : public DeviceAllocator {
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryInfo();
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
|
||||
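Together with the `get_memory_info` entry added to the XPU docs below, the allocator-level `getMemoryInfo` above is what backs a Python-level free/total memory query. A rough usage sketch, assuming the binding is exposed as `torch.xpu.get_memory_info` and returns `(free, total)` in bytes like its CUDA counterpart `torch.cuda.mem_get_info`; the exact Python signature is an assumption here:

```python
import torch

# Assumes an XPU build; the function name and (free, total) return convention
# follow the C++ getMemoryInfo added above, but are assumptions in this sketch.
if torch.xpu.is_available():
    free, total = torch.xpu.get_memory_info(torch.xpu.current_device())
    print(f"XPU memory: {free / 2**30:.2f} GiB free of {total / 2**30:.2f} GiB")
```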
@ -1941,6 +1941,7 @@ if(BUILD_TEST)
|
||||
foreach(test_src ${Caffe2_XPU_TEST_SRCS})
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} "${test_src}")
|
||||
torch_compile_options(${test_name})
|
||||
target_link_libraries(${test_name} torch_library gtest_main)
|
||||
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
|
||||
|
||||
@ -73,6 +73,19 @@ void box_cox_zero_lambda(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
at::vec::Vectorized<T> box_cox_nonzero_lambda_impl(
|
||||
at::vec::Vectorized<T> data,
|
||||
at::vec::Vectorized<T> lambda1,
|
||||
at::vec::Vectorized<T> lambda2,
|
||||
at::vec::Vectorized<T> k_eps) {
|
||||
auto sum = data + lambda2;
|
||||
auto max = at::vec::max(sum, k_eps);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
|
||||
auto pow = max.pow(lambda1);
|
||||
return at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void box_cox_nonzero_lambda(
|
||||
int64_t D,
|
||||
@ -88,21 +101,18 @@ void box_cox_nonzero_lambda(
|
||||
auto k_eps_vec = Vec(k_eps);
|
||||
for(; j + VLEN < D; j += VLEN) {
|
||||
auto data = Vec::loadu(data_ptr + j);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j);
|
||||
auto sum = data + lambda2;
|
||||
auto max = at::vec::max(sum, k_eps_vec);
|
||||
auto lambda1 = Vec::loadu(lambda1_ptr + j);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1);
|
||||
auto pow = max.pow(lambda1);
|
||||
auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j);
|
||||
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
|
||||
res.store(out + j);
|
||||
}
|
||||
for ( ;j < D; ++j) {
|
||||
auto sum = data_ptr[j] + lambda2_ptr[j];
|
||||
auto max = std::max(sum, k_eps);
|
||||
auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]);
|
||||
auto pow = std::pow(max, lambda1_ptr[j]);
|
||||
out[j] = pow * lambda_over_1 - lambda_over_1;
|
||||
if (j < D) {
|
||||
auto remaining = D - j;
|
||||
auto data = Vec::loadu(data_ptr + j, remaining);
|
||||
auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining);
|
||||
auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining);
|
||||
auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec);
|
||||
res.store(out + j, remaining);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
||||
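For reference, each element produced by `box_cox_nonzero_lambda_impl` above is the shifted Box-Cox transform with the epsilon clamp applied before the power, which can be read directly off the vectorized code:

$$
y_j = \frac{\max(x_j + \lambda_{2,j},\ \varepsilon)^{\lambda_{1,j}} - 1}{\lambda_{1,j}},
$$

computed as `fmsub(pow, 1/lambda1, 1/lambda1)`, i.e. `pow * (1/lambda1) - (1/lambda1)`. The refactor also lets the loop tail reuse the same vector path through a masked `loadu`/`store` over the remaining `D - j` elements instead of the previous scalar fallback.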
@ -40,6 +40,7 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_memory_info
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
memory_allocated
|
||||
|
||||
@ -46,6 +46,108 @@ These headers are promised to be ABI stable across releases and adhere to a stro
|
||||
Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable`
|
||||
which will handle all the rough edges of the C API for the user.
|
||||
|
||||
## Migrating your kernel to the LibTorch stable ABI
|
||||
|
||||
If you'd like your kernel to be ABI stable with LibTorch, meaning you'd like the ability to build against one version and run on another, your kernel must only use the limited stable ABI. The following section goes through the steps of migrating an existing kernel and the APIs we imagine you would need to swap over.
|
||||
|
||||
First, instead of registering kernels through `TORCH_LIBRARY`, ABI-stable LibTorch kernels must be registered via `STABLE_TORCH_LIBRARY`. Note that, for the time being, implementations registered via `STABLE_TORCH_LIBRARY` must be boxed, unlike `TORCH_LIBRARY`. See the simple example below or our docs on [Stack-based APIs](stack-based-apis) for more details. For kernels that are currently registered via `pybind`, it would be useful to first migrate them to `TORCH_LIBRARY` before adopting the stable ABI.
|
||||
|
||||
While previously your kernels might have included APIs from `<torch/*.h>` (for example, `<torch/all.h>`), they are now limited to including from the three categories of headers mentioned above (`torch/csrc/stable/*.h`, `torch/headeronly/*.h` and the stable C headers). This means that your extension should no longer use any utilities from the `at::` or `c10::` namespaces, but should instead use their replacements in `torch::stable` and `torch::headeronly`. To give a couple of examples of the necessary migrations:
|
||||
- all uses of `at::Tensor` must be replaced with `torch::stable::Tensor`
|
||||
- all uses of `TORCH_CHECK` must be replaced with `STD_TORCH_CHECK`
|
||||
- all uses of `at::kCUDA` must be replaced with `torch::headeronly::kCUDA` etc.
|
||||
- native functions such as `at::pad` must be replaced with `torch::stable::pad`
|
||||
- native functions that are called as Tensor methods (e.g., `Tensor.pad`) must be replaced with the ATen variant through `torch::stable::pad`.
|
||||
|
||||
As mentioned above, the LibTorch stable ABI is still under development. If there is any API or feature you would like to see added to the stable ABI/`torch::headeronly`/`torch::stable`, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues).
|
||||
|
||||
Below is a simple example of migrating an existing kernel that uses `TORCH_LIBRARY` to the stable ABI (`STABLE_TORCH_LIBRARY`). For a larger end-to-end example, you can take a look at the FA3 repository, specifically the diff between [`flash_api.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api.cpp#L1) and the stable variant [`flash_api_stable.cpp`](https://github.com/Dao-AILab/flash-attention/blob/ad70a007e6287d4f7e766f94bcf2f9a813f20f6b/hopper/flash_api_stable.cpp#L1).
|
||||
|
||||
|
||||
### Original Version with `TORCH_LIBRARY`
|
||||
|
||||
```cpp
|
||||
// original_kernel.cpp - Using TORCH_LIBRARY (not stable ABI)
|
||||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
namespace myops {
|
||||
|
||||
// Simple kernel that adds a scalar value to each element of a tensor
|
||||
at::Tensor add_scalar(const at::Tensor& input, double scalar) {
|
||||
TORCH_CHECK(input.scalar_type() == at::kFloat, "Input must be float32");
|
||||
|
||||
return input.add(scalar);
|
||||
}
|
||||
|
||||
// Register the operator
|
||||
TORCH_LIBRARY(myops, m) {
|
||||
m.def("add_scalar(Tensor input, float scalar) -> Tensor", &add_scalar);
|
||||
}
|
||||
|
||||
// Register the implementation
|
||||
TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
|
||||
m.impl("add_scalar", &add_scalar);
|
||||
}
|
||||
|
||||
} // namespace myops
|
||||
```
|
||||
|
||||
### Migrated Version with `STABLE_TORCH_LIBRARY`
|
||||
|
||||
```cpp
|
||||
// stable_kernel.cpp - Using STABLE_TORCH_LIBRARY (stable ABI)
|
||||
|
||||
// (1) Don't include <torch/torch.h> or <ATen/ATen.h>;
|
||||
// only include APIs from torch/csrc/stable, torch/headeronly and C-shims
|
||||
#include <torch/csrc/stable/library.h>
|
||||
#include <torch/csrc/stable/tensor_struct.h>
|
||||
#include <torch/csrc/stable/ops.h>
|
||||
#include <torch/csrc/stable/stableivalue_conversions.h>
|
||||
#include <torch/headeronly/core/ScalarType.h>
|
||||
#include <torch/headeronly/macros/Macros.h>
|
||||
|
||||
namespace myops {
|
||||
|
||||
// Simple kernel that adds a scalar value to each element of a tensor
|
||||
torch::stable::Tensor add_scalar(const torch::stable::Tensor& input, double scalar) {
|
||||
// (2) use STD_TORCH_CHECK instead of TORCH_CHECK
|
||||
STD_TORCH_CHECK(
|
||||
// (3) use torch::headeronly::kFloat instead of at::kFloat
|
||||
input.scalar_type() == torch::headeronly::kFloat,
|
||||
"Input must be float32");
|
||||
|
||||
// (4) Use stable ops namespace instead of input.add
|
||||
return torch::stable::add(input, scalar);
|
||||
}
|
||||
|
||||
// (5) Add Boxed wrapper required for STABLE_TORCH_LIBRARY
|
||||
void boxed_add_scalar(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
// Extract arguments from stack using `to<T>`
|
||||
auto input = to<torch::stable::Tensor>(stack[0]);
|
||||
auto scalar = to<double>(stack[1]);
|
||||
|
||||
// Call the actual kernel
|
||||
auto result = add_scalar(input, scalar);
|
||||
|
||||
// Put result back on stack using `from()`
|
||||
// Stack slot 0 now holds the return value
|
||||
stack[0] = from(result);
|
||||
}
|
||||
|
||||
// (6) Register the operator using STABLE_TORCH_LIBRARY
|
||||
STABLE_TORCH_LIBRARY(myops, m) {
|
||||
m.def("add_scalar(Tensor input, float scalar) -> Tensor", &boxed_add_scalar);
|
||||
}
|
||||
|
||||
// (7) Register the implementation using STABLE_TORCH_LIBRARY_IMPL
|
||||
STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
|
||||
m.impl("add_scalar", &boxed_add_scalar);
|
||||
}
|
||||
|
||||
} // namespace myops
|
||||
```
|
||||
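If it helps to see the migrated operator in action, here is a hedged sketch of calling it from Python once the extension has been compiled into a shared library. The library path is illustrative, while `torch.ops.load_library` and the `torch.ops.<namespace>.<op>` calling convention are the standard mechanisms:

```python
import torch

# Load the compiled extension (the path/filename is an assumption for this example).
torch.ops.load_library("build/libmyops.so")

x = torch.ones(4)  # float32, matching the STD_TORCH_CHECK in the kernel
y = torch.ops.myops.add_scalar(x, 1.5)
print(y)  # tensor([2.5000, 2.5000, 2.5000, 2.5000])
```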
|
||||
|
||||
## How are objects passed across the ABI boundary when interacting with the dispatcher?
|
||||
|
||||
@ -109,6 +211,7 @@ There are two invariants for the stack:
|
||||
a. When calling a stack-based API, you must give owning references to the calling stack and steal references from the returned stack.
|
||||
b. When registering your function to be called with a stack, you must steal references from your argument stack and push onto the stack new references.
|
||||
|
||||
(stack-based-apis)=
|
||||
### Stack-based APIs
|
||||
|
||||
The above is relevant in two places:
|
||||
|
||||
@ -172,9 +172,9 @@ ignore = [
|
||||
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
|
||||
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||
"SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
|
||||
"SIM110",
|
||||
"SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all.
|
||||
"SIM114", # Combine `if` branches using logical `or` operator
|
||||
"SIM115",
|
||||
"SIM115", # Checks for cases where files are opened without using a context manager.
|
||||
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
|
||||
"SIM117",
|
||||
"SIM118",
|
||||
@ -184,7 +184,6 @@ ignore = [
|
||||
"TC006",
|
||||
# TODO: Remove Python-3.10 specific suppressions
|
||||
"B905",
|
||||
"UP035",
|
||||
]
|
||||
select = [
|
||||
"B",
|
||||
|
||||
setup.py (33 changed lines)
@@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
        raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")


def mirror_inductor_external_kernels() -> None:
    """
    Copy external kernels into Inductor so they are importable.
    """
    paths = [
        (
            CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
            CWD
            / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
        ),
    ]
    for new_path, orig_path in paths:
        # Create the dirs involved in new_path if they don't exist
        if not new_path.exists():
            new_path.parent.mkdir(parents=True, exist_ok=True)

        # Copy the files from the orig location to the new location
        if orig_path.is_file():
            shutil.copyfile(orig_path, new_path)
            continue
        if orig_path.is_dir():
            if new_path.exists():
                # copytree fails if the tree exists already, so remove it.
                shutil.rmtree(new_path)
            shutil.copytree(orig_path, new_path)
            continue
        raise RuntimeError(
            "Check the file paths in `mirror_inductor_external_kernels()`"
        )


# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
    """Extract variant from version string, defaulting to 'cpu'."""
@@ -1615,6 +1646,7 @@ def main() -> None:
    mirror_files_into_torchgen()
    if RUN_BUILD_DEPS:
        build_deps()
    mirror_inductor_external_kernels()

    (
        ext_modules,
@@ -1649,6 +1681,7 @@ def main() -> None:
            "_inductor/codegen/aoti_runtime/*.cpp",
            "_inductor/script.ld",
            "_inductor/kernel/flex/templates/*.jinja",
            "_inductor/kernel/templates/*.jinja",
            "_export/serde/*.yaml",
            "_export/serde/*.thrift",
            "share/cmake/ATen/*.cmake",
@@ -208,7 +208,7 @@ class _BaseDataSparsiferTestCase(TestCase):
        assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)

        state1 = state_dict1["state"]
        for name in state1.keys():
        for name in state1:
            # compare mask
            assert name in sparsifier2.state
            assert "mask" in sparsifier2.state[name]

@@ -75,6 +75,7 @@ class TestScheduler(TestCase):

class TestCubicScheduler(TestCase):
    def setUp(self):
        super().setUp()
        self.model_sparse_config = [
            {"tensor_fqn": "0.weight", "sparsity_level": 0.8},
            {"tensor_fqn": "2.weight", "sparsity_level": 0.4},

@@ -119,7 +119,7 @@ class TestBaseSparsifier(TestCase):
        for idx in range(len(sparsifier0.groups)):
            mg0 = sparsifier0.groups[idx]
            mg1 = sparsifier1.groups[idx]
            for key in mg0.keys():
            for key in mg0:
                assert key in mg1
                if key == "module":
                    # We cannot compare modules as they are different

@@ -11,6 +11,7 @@ from torch.testing._internal.common_utils import IS_LINUX, run_tests, TestCase
@unittest.skipIf(not IS_LINUX, "Only works on linux")
class TestTorchrun(TestCase):
    def setUp(self):
        super().setUp()
        self._test_dir = tempfile.mkdtemp(prefix=self.__class__.__name__)

    def tearDown(self):
@@ -10,6 +10,8 @@ set(AOTI_ABI_CHECK_TEST_SRCS
  ${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
  ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
test/cpp/aoti_abi_check/test_dispatch.cpp (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
|
||||
// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For
|
||||
// testing, we do nothing:
|
||||
#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
|
||||
|
||||
#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \
|
||||
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
|
||||
MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
|
||||
|
||||
#define MY_DISPATCH_CASE(...) \
|
||||
THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__)
|
||||
|
||||
// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch
|
||||
// statement. For testing, we just avoid unused variable warning:
|
||||
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
|
||||
(void)DISPATCHNAME
|
||||
|
||||
// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For
|
||||
// testing, we count case mismatches:
|
||||
#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++
|
||||
|
||||
#define MY_DISPATCH_SWITCH(...) \
|
||||
THO_DISPATCH_SWITCH_TMPL( \
|
||||
MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)
|
||||
|
||||
// MY_CASE_FUNCTION is called in a case block. For testing, we count
|
||||
// case matches and ensure that scalar_t/index_t type is defined:
|
||||
#define MY_CASE_FUNCTION \
|
||||
[&] { \
|
||||
count++; \
|
||||
scalar_t tmp; \
|
||||
(void)tmp; \
|
||||
}
|
||||
#define MY_INDEX_CASE_FUNCTION \
|
||||
[&] { \
|
||||
count++; \
|
||||
index_t tmp; \
|
||||
(void)tmp; \
|
||||
}
|
||||
|
||||
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
|
||||
|
||||
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
MY_DISPATCH_SWITCH, \
|
||||
MY_DISPATCH_CASE, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_WRAP(BODY), \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestDispatchV2, NAME) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t total_count = 0; \
|
||||
int8_t count = 0; \
|
||||
int8_t default_count = 0; \
|
||||
for (ScalarType t : \
|
||||
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
|
||||
total_count++; \
|
||||
MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \
|
||||
} \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
EXPECT_EQ(default_count + count, total_count); \
|
||||
}
|
||||
|
||||
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
|
||||
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
|
||||
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
|
||||
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
|
||||
|
||||
#undef DEFINE_ITEM
|
||||
test/cpp/aoti_abi_check/test_dispatch_v2.cpp (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
|
||||
|
||||
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestThoDispatchV2, NAME) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t total_count = 0; \
|
||||
int8_t count = 0; \
|
||||
int8_t default_count = 0; \
|
||||
for (ScalarType t : \
|
||||
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
|
||||
total_count++; \
|
||||
try { \
|
||||
THO_DISPATCH_V2( \
|
||||
t, \
|
||||
"test_tho_dispatch_v2", \
|
||||
[&] { \
|
||||
count++; \
|
||||
scalar_t tmp; \
|
||||
(void)tmp; \
|
||||
}, \
|
||||
__VA_ARGS__); \
|
||||
} catch (...) { \
|
||||
default_count++; /* counts mismatches */ \
|
||||
} \
|
||||
} \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
EXPECT_EQ(default_count + count, total_count); \
|
||||
}
|
||||
|
||||
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
|
||||
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
|
||||
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
|
||||
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
|
||||
|
||||
#undef DEFINE_ITEM
|
||||
@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
|
||||
|
||||
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = sgd_out_of_place(
|
||||
to<Tensor>(stack[0]),
|
||||
to<Tensor>(stack[1]),
|
||||
float(to<double>(stack[2])),
|
||||
to<double>(stack[3]),
|
||||
to<bool>(stack[4]));
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<Tensor>(stack[1]),
|
||||
float(torch::stable::detail::to<double>(stack[2])),
|
||||
torch::stable::detail::to<double>(stack[3]),
|
||||
torch::stable::detail::to<bool>(stack[4]));
|
||||
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
|
||||
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = identity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
Tensor my_abs(Tensor t) {
|
||||
const auto num_args = 1;
|
||||
StableIValue stack[num_args];
|
||||
stack[0] = from(t);
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::abs", "", stack);
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
|
||||
|
||||
auto mf = aoti_torch_memory_format_contiguous_format();
|
||||
|
||||
stack[0] = from(t);
|
||||
stack[1] = from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = from(std::nullopt); // layout
|
||||
stack[3] = from(std::optional(device)); // device
|
||||
stack[4] = from(std::optional(false)); // pin_memory
|
||||
stack[5] = from(std::optional(mf)); // memory_format
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = torch::stable::detail::from(std::nullopt); // layout
|
||||
stack[3] = torch::stable::detail::from(std::optional(device)); // device
|
||||
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
|
||||
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
|
||||
|
||||
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
|
||||
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = from(res);
|
||||
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = from(t1);
|
||||
stack_exp[0] = torch::stable::detail::from(t1);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = from(t2);
|
||||
stack_neg[0] = torch::stable::detail::from(t2);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
StableIValue stack_is_leaf[1];
|
||||
stack_is_leaf[0] = from(t3);
|
||||
stack_is_leaf[0] = torch::stable::detail::from(t3);
|
||||
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
|
||||
|
||||
return std::make_tuple(
|
||||
to<Tensor>(stack_exp[0]),
|
||||
to<Tensor>(stack_neg[0]),
|
||||
to<bool>(stack_is_leaf[0]));
|
||||
torch::stable::detail::to<Tensor>(stack_exp[0]),
|
||||
torch::stable::detail::to<Tensor>(stack_neg[0]),
|
||||
torch::stable::detail::to<bool>(stack_is_leaf[0]));
|
||||
}
|
||||
|
||||
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
|
||||
stack[0] = from(std::get<0>(tuple));
|
||||
stack[1] = from(std::get<1>(tuple));
|
||||
stack[2] = from(std::get<2>(tuple));
|
||||
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
|
||||
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
|
||||
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor neg_exp(Tensor t) {
|
||||
StableIValue stack[1];
|
||||
stack[0] = from(t);
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack);
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor divide_neg_exp(Tensor t) {
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = from(t);
|
||||
stack_neg[0] = torch::stable::detail::from(t);
|
||||
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = from(t);
|
||||
stack_exp[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
|
||||
stack_div[0] = stack_neg[0];
|
||||
stack_div[1] = stack_exp[0];
|
||||
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
|
||||
return to<Tensor>(stack_div[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack_div[0]);
|
||||
}
|
||||
|
||||
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
bool res = is_contiguous(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
|
||||
}
|
||||
|
||||
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
|
||||
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
|
||||
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_empty_like(Tensor t) {
|
||||
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_empty_like(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
bool my_is_cpu(Tensor t) {
|
||||
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
|
||||
|
||||
|
||||
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_is_cpu(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor fill_infinity(Tensor t) {
|
||||
@ -296,8 +296,8 @@ void boxed_fill_infinity(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = fill_infinity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_pad(Tensor t) {
|
||||
@ -310,8 +310,8 @@ void boxed_my_pad(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_pad(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
|
||||
@ -323,11 +323,11 @@ void boxed_my_narrow(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_narrow(
|
||||
to<Tensor>(stack[0]),
|
||||
to<int64_t>(stack[1]),
|
||||
to<int64_t>(stack[2]),
|
||||
to<int64_t>(stack[3]));
|
||||
stack[0] = from(res);
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<int64_t>(stack[1]),
|
||||
torch::stable::detail::to<int64_t>(stack[2]),
|
||||
torch::stable::detail::to<int64_t>(stack[3]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
}
|
||||
|
||||
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
Tensor my_clone(Tensor t) {
|
||||
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
|
||||
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_zero_(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax(Tensor t) {
|
||||
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax_vec(Tensor t) {
|
||||
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax_vec(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
bool res = test_default_constructor(to<bool>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -478,6 +478,56 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_amax_vec", &boxed_my_amax_vec);
|
||||
}
|
||||
|
||||
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
|
||||
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
|
||||
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
|
||||
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
|
||||
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
}
|
||||
|
||||
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
|
||||
// This function tests that my__foreach_mul can take in std::initializer_lists
|
||||
// in addition to std::vectors.
|
||||
Tensor t1_1 = my_clone(t1);
|
||||
Tensor t1_2 = my_clone(t1);
|
||||
Tensor t2_1 = my_clone(t2);
|
||||
Tensor t2_2 = my_clone(t2);
|
||||
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
|
||||
}
|
||||
|
||||
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
|
||||
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
|
||||
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
|
||||
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
|
||||
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
|
||||
}
|
||||
|
||||
// Test functions for torch::stable::accelerator APIs
|
||||
|
||||
#ifdef LAE_USE_CUDA
|
||||
@ -500,8 +550,8 @@ void boxed_test_device_guard(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_device_guard_set_index() {
|
||||
@ -520,7 +570,7 @@ void boxed_test_device_guard_set_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_device_guard_set_index();
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_stream(int32_t device_index) {
|
||||
@ -536,8 +586,8 @@ void boxed_test_stream(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_get_current_device_index() {
|
||||
@ -549,7 +599,7 @@ void boxed_test_get_current_device_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_get_current_device_index();
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -565,4 +615,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_stream", &boxed_test_stream);
|
||||
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
|
||||
}
|
||||
|
||||
#endif // LAE_USE_CUDA
|
||||
|
||||
@@ -333,3 +333,45 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
    Returns: New zeros tensor
    """
    return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)


def my__foreach_mul_(tensors, others) -> ():
    """
    Updates tensors to be the result of pointwise multiplying with others.

    Args:
        tensors: list of tensors
        others: list of tensors (with the same corresponding shapes as tensors)

    Returns: nothing, tensors is updated in place.
    """
    torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)


def my__foreach_mul(tensors, others) -> list[Tensor]:
    """
    Returns a list of tensors that are the results of pointwise multiplying
    tensors and others.

    Args:
        tensors: list of tensors
        others: list of tensors (with the same corresponding shapes as tensors)

    Returns: list of multiplied tensors
    """
    return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)


def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
    """
    Returns a list of 2 tensors corresponding to the square of the inputs.

    Args:
        t1: Tensor
        t2: Tensor

    Returns: list of [t1^2, t2^2]
    """
    return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
        t1, t2
    )
@ -367,6 +367,57 @@ if not IS_WINDOWS:
|
||||
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
|
||||
self.assertEqual(result.stride(), expected.stride())
|
||||
|
||||
def test_my__foreach_mul_(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
tensors_c = [t.clone() for t in tensors]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
|
||||
expected_values = torch._foreach_mul(tensors_c, others)
|
||||
|
||||
for tensor_t, expected_t in zip(tensors, expected_values):
|
||||
self.assertEqual(tensor_t, expected_t)
|
||||
|
||||
def test_my__foreach_mul(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
|
||||
for result_t, expected_t in zip(result, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
def _make_cuda_tensors(prior_mem):
|
||||
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
|
||||
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
for result_t, expected_t in zip(cuda_res, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
if tensors[0].is_cuda:
|
||||
init_mem = torch.cuda.memory_allocated(device)
|
||||
for _ in range(3):
|
||||
_make_cuda_tensors(init_mem)
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
def test_make_tensor_clones_and_call_foreach(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t1 = torch.rand(2, 5, device=device)
|
||||
t2 = torch.rand(3, 4, device=device)
|
||||
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
|
||||
self.assertEqual(result[0], t1 * t1)
|
||||
self.assertEqual(result[1], t2 * t2)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from backend import get_custom_backend_library_path, Model, to_custom_backend
|
||||
@ -11,6 +10,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestCustomBackend(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Load the library containing the custom backend.
|
||||
self.library_path = get_custom_backend_library_path()
|
||||
torch.ops.load_library(self.library_path)
|
||||
@ -40,14 +40,11 @@ class TestCustomBackend(TestCase):
|
||||
self.test_execute()
|
||||
|
||||
# Save and load.
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.close()
|
||||
torch.jit.save(self.model, f.name)
|
||||
loaded = torch.jit.load(f.name)
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
self.model = loaded
|
||||
self.model = loaded
|
||||
|
||||
# Test execution again.
|
||||
self.test_execute()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -18,6 +17,7 @@ torch.ops.import_module("pointwise")
|
||||
|
||||
class TestCustomOperators(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.library_path = get_custom_op_library_path()
|
||||
ops.load_library(self.library_path)
|
||||
|
||||
@ -143,16 +143,13 @@ def forward(self, arg0_1):
|
||||
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
||||
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
||||
# close the file after creation and try to remove it manually.
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
file.close()
|
||||
model.save(file.name)
|
||||
loaded = torch.jit.load(file.name)
|
||||
finally:
|
||||
os.unlink(file.name)
|
||||
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: fsdp"]
|
||||
import functools
|
||||
import os
|
||||
import unittest.mock
|
||||
import unittest
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.test_case import run_tests
|
||||
@ -37,9 +37,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
logger = logging.getLogger("torch.distributed._composable.fsdp")
|
||||
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
device = {device_type.type}
|
||||
device = '{device_type.type}'
|
||||
torch.manual_seed(0)
|
||||
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
|
||||
for layer in model:
|
||||
|
||||
@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
store=dist.FileStore(self.file_name, self.world_size),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer(self):
|
||||
"""
|
||||
This tests that replicate works on a transformer model with fully_shard and replicate layers
|
||||
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Shard(dim=0),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer_managed_modules(self):
|
||||
"""
|
||||
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
|
||||
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
replicate_model = replicate(replicate_model)
|
||||
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_tp_device_mesh(self):
|
||||
"""
|
||||
This tests that a user can pass in a device mesh to replicate a module
|
||||
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_replicate_fsdp(self):
|
||||
"""
|
||||
Tests that replicate_model has the same behavior as original model when training
|
||||
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(replicate_loss, loss)
|
||||
check_sharded_parity(self, model, replicate_model)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_parity_2d_mlp(self):
|
||||
"""
|
||||
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
|
||||
|
||||
@ -80,7 +80,7 @@ class TestSACILP(TestCase):
|
||||
# postprocessing due to the fact that for ModTracker, the post backward hook
|
||||
# is not being called for modules whose inputs don't require gradients
|
||||
# TODO: fix this in ModTracker and ensure it does not lead to any perf regression
|
||||
if _ModState.POST_BW not in mod_stats.snapshots.keys():
|
||||
if _ModState.POST_BW not in mod_stats.snapshots:
|
||||
mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
|
||||
copy.deepcopy(last_snapshot)
|
||||
)
|
||||
|
||||
@ -16,7 +16,7 @@ from torch.distributed.argparse_util import check_env, env
|
||||
class ArgParseUtilTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# remove any lingering environment variables
|
||||
for e in os.environ.keys():
|
||||
for e in os.environ.keys(): # noqa: SIM118
|
||||
if e.startswith("PET_"):
|
||||
del os.environ[e]
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestMakeCheckpointer(TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
# Create a temporary directory for checkpoints
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
@ -161,6 +161,7 @@ class TestCheckpointProcessConfig(TestCase):
|
||||
|
||||
class TestCheckpointProcess(TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
"""Set up common test fixtures."""
|
||||
self.rank_info = RankInfo(
|
||||
global_world_size=1,
|
||||
|
||||
@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
class TestCheckpointReader(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Create a temporary directory for test checkpoints
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
@ -52,6 +52,7 @@ class TestCheckpointWriterConfig(TestCase):
|
||||
|
||||
class TestCheckpointWriter(TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Create a temporary directory for test checkpoints
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
@ -52,6 +52,7 @@ class TestCheckpointer(TestCase):
|
||||
"""Parameterized tests that work with both sync and async checkpointers."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Create a temporary directory for checkpoints
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
@ -397,6 +398,7 @@ class TestAsyncCheckpointerSpecific(TestCase):
|
||||
"""Tests specific to AsyncCheckpointer functionality."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Create a temporary directory for checkpoints
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
|
||||
|
||||
@ -12,6 +12,7 @@ from torch.testing._internal.common_utils import requires_cuda, run_tests, TestC
|
||||
|
||||
class TestDefaultStager(TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
# Create a test state dictionary with various data types
|
||||
self.state_dict = {
|
||||
"model": torch.nn.Linear(10, 5).state_dict(),
|
||||
@ -206,7 +207,7 @@ class TestDefaultStager(TestCase):
|
||||
for i, result in enumerate(staged_results):
|
||||
self.assertIsInstance(result, dict)
|
||||
# Verify the result contains the expected keys
|
||||
for key in state_dicts[i].keys():
|
||||
for key in state_dicts[i]:
|
||||
self.assertIn(key, result)
|
||||
|
||||
stager.close()
|
||||
|
||||
@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
|
||||
"""
|
||||
Saving a dtensor with uneven shards.
|
||||
@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_checkpointable_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.