mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Compare commits
198 Commits
malfet-pat
...
ciflow/tru
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e742f09be | |||
| 8de609d4dd | |||
| dc73ebb8d9 | |||
| b6dff81780 | |||
| d5040b0741 | |||
| 2c78080ec0 | |||
| fe6615e397 | |||
| abf31db2cc | |||
| a4c7856112 | |||
| afb014541b | |||
| a0eb49f5bf | |||
| 1d6ef20ea1 | |||
| b91a2ab892 | |||
| 14a845a4ec | |||
| 5135ace3a3 | |||
| e7c1905837 | |||
| 9cf623a209 | |||
| 06aa3ef3d3 | |||
| 0384104e23 | |||
| 325ec98009 | |||
| 47acdea74a | |||
| 71606b289c | |||
| e342a7509a | |||
| 27ac58bd70 | |||
| 406719c3da | |||
| 957570e4a3 | |||
| 56fa205477 | |||
| 9e7c167f90 | |||
| eeb6c96a89 | |||
| 0b12e49795 | |||
| 87646e5db4 | |||
| 29d6bb79e1 | |||
| c2924bbafa | |||
| a2f109dcc3 | |||
| ba5ffa2dca | |||
| c131e4b390 | |||
| 7fd15aa2bd | |||
| e76ccdda94 | |||
| ef4a7d1d74 | |||
| c45c966031 | |||
| d18c742779 | |||
| 4957ae5838 | |||
| 31d6d3ef5c | |||
| 2325c511e7 | |||
| d865156967 | |||
| fbc0bd2e90 | |||
| 70f5f55abf | |||
| 69ecb562e7 | |||
| 5062abe4e7 | |||
| c7007e7584 | |||
| 09705ca9b2 | |||
| ea6b0b5d0f | |||
| bbf852d87f | |||
| 6392b986e7 | |||
| 32d30d96cf | |||
| 46516efa85 | |||
| 84b2147b85 | |||
| 1727a71cb6 | |||
| fb9e10fe25 | |||
| 4e277e6323 | |||
| ba327b7a5c | |||
| 8eb21304ab | |||
| b83a3f6e87 | |||
| 289b47e657 | |||
| c20308b79e | |||
| 4c41e9bde7 | |||
| 2f5223564e | |||
| 28615a765d | |||
| d1446ad75c | |||
| e401a56b96 | |||
| 22650c89fb | |||
| c62a17a2fb | |||
| 713e289ae7 | |||
| 69784a0dbe | |||
| 3c2409c465 | |||
| 724cd32b0c | |||
| b62935d1a5 | |||
| ccc8c117dc | |||
| 86db4de10f | |||
| 12860892f8 | |||
| 694592ac1e | |||
| 285748e838 | |||
| 97e5c03133 | |||
| 7c9a35c0c4 | |||
| 0dd90a2627 | |||
| a429c5de12 | |||
| 192034c41b | |||
| 2972d471dd | |||
| 625d31b578 | |||
| 5bfce8f345 | |||
| edd611f3b0 | |||
| aded2ebb90 | |||
| 5bda7afa05 | |||
| fe739c2255 | |||
| 37c871999d | |||
| aacbec47a5 | |||
| 0cc9cc6937 | |||
| 341e924981 | |||
| 5a9ae7cefe | |||
| 3d59e8aadf | |||
| 4cf1d1af22 | |||
| 05b8214e6a | |||
| 35d2da32bd | |||
| 0968e74266 | |||
| 57dd6a0656 | |||
| 7318ed627b | |||
| 5b2ad2d5dc | |||
| faba6e205f | |||
| 3261149aa3 | |||
| bd7e18bc57 | |||
| 643b3bc8f3 | |||
| 91b626e2ef | |||
| c24a40536b | |||
| 06b3dd9a9b | |||
| bf8297afe0 | |||
| 3f03f84ce2 | |||
| 8a72188828 | |||
| d325aa1877 | |||
| 7aedf3a576 | |||
| eaf4815c1f | |||
| a913b2bb93 | |||
| 1632876edf | |||
| 0e1f76f77e | |||
| ae67a5a9d3 | |||
| 292bd62c71 | |||
| 0e512ee9f0 | |||
| 31ac764239 | |||
| b228f6d180 | |||
| e678450a69 | |||
| 552c3f3e18 | |||
| 5b36e4e30f | |||
| 8cbd557a44 | |||
| 215937b262 | |||
| c0d8211567 | |||
| 35b0f1979a | |||
| 7f254deda5 | |||
| 42b7121579 | |||
| 89461902e7 | |||
| 7e75d9c1af | |||
| 8b6ea87350 | |||
| 77fdb8fc0b | |||
| b276759fa9 | |||
| 98795824a3 | |||
| 982c3afc1a | |||
| ca97452084 | |||
| f5dc9ed4a7 | |||
| 13cff139b1 | |||
| 642532d99a | |||
| 6905e60009 | |||
| 2a709d3c5c | |||
| 146f74745c | |||
| 3dcb697423 | |||
| 8fc312080c | |||
| 137bb17393 | |||
| 506c102be6 | |||
| 7bd0e7b13a | |||
| 3f6d5c7882 | |||
| 0a45a80a1e | |||
| 20a2bf1848 | |||
| d2d7e2e007 | |||
| 0c3b32cdfe | |||
| b831219577 | |||
| 7a5415fa49 | |||
| 3fcdb0331a | |||
| 7c0318f29e | |||
| 870921f43d | |||
| e094178954 | |||
| 8629531760 | |||
| 7834da3d0c | |||
| 444425428a | |||
| 51a17f064b | |||
| 191402dffc | |||
| 8775c579e1 | |||
| f4af4f029a | |||
| 2e5aa1b8bc | |||
| 654ec31cf8 | |||
| 1b3f9fb167 | |||
| 6d5ff9e966 | |||
| faaa219abc | |||
| 7da62e7e6e | |||
| 007973b088 | |||
| 1f5c39f11f | |||
| 0fca020022 | |||
| 8263baabb1 | |||
| f0b966f925 | |||
| 3ca2fb4e55 | |||
| ec2da971dc | |||
| e4ee616def | |||
| 7f31805e72 | |||
| e3e16adff0 | |||
| c202d5b1fc | |||
| 7e9405da0b | |||
| 79768226c9 | |||
| 08cf8c67c2 | |||
| 2fac214138 | |||
| e830571b16 | |||
| 8ca16467fc | |||
| 864d750662 |
@ -36,11 +36,7 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
;;
|
||||
rocm*)
|
||||
BASE_TARGET=rocm
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -168,6 +168,18 @@ case "$tag" in
|
||||
VISION=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.11-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.11
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-clang12)
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
CLANG_VERSION=12
|
||||
VISION=no
|
||||
TRITON=no
|
||||
;;
|
||||
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
|
||||
if [[ $tag =~ "jammy" ]]; then
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
@ -195,9 +207,9 @@ case "$tag" in
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
|
||||
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
|
||||
ANACONDA_PYTHON_VERSION=3.10
|
||||
GCC_VERSION=11
|
||||
GCC_VERSION=13
|
||||
VISION=yes
|
||||
XPU_VERSION=2025.2
|
||||
NINJA_VERSION=1.9.0
|
||||
@ -248,6 +260,12 @@ case "$tag" in
|
||||
HALIDE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
|
||||
CUDA_VERSION=12.8.1
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
GCC_VERSION=11
|
||||
PALLAS=yes
|
||||
;;
|
||||
pytorch-linux-jammy-py3.12-triton-cpu)
|
||||
CUDA_VERSION=12.6
|
||||
ANACONDA_PYTHON_VERSION=3.12
|
||||
@ -369,6 +387,7 @@ docker build \
|
||||
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
||||
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
||||
--build-arg "HALIDE=${HALIDE}" \
|
||||
--build-arg "PALLAS=${PALLAS}" \
|
||||
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
||||
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
||||
--build-arg "ACL=${ACL:-}" \
|
||||
|
||||
1
.ci/docker/ci_commit_pins/jax.txt
Normal file
1
.ci/docker/ci_commit_pins/jax.txt
Normal file
@ -0,0 +1 @@
|
||||
0.8.0
|
||||
40
.ci/docker/common/install_jax.sh
Executable file
40
.ci/docker/common/install_jax.sh
Executable file
@ -0,0 +1,40 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
|
||||
|
||||
# Get the pinned JAX version (same for all CUDA versions)
|
||||
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
|
||||
|
||||
function install_jax_12() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
|
||||
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
|
||||
}
|
||||
|
||||
function install_jax_13() {
|
||||
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
|
||||
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
|
||||
|
||||
# Verify installation
|
||||
python -c "import jax" # check for errors
|
||||
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
|
||||
}
|
||||
|
||||
# idiomatic parameter and option handling in sh
|
||||
while test $# -gt 0
|
||||
do
|
||||
case "$1" in
|
||||
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
|
||||
;;
|
||||
13.0|13.0.*) install_jax_13;
|
||||
;;
|
||||
*) echo "bad argument $1"; exit 1
|
||||
;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
@ -9,7 +9,7 @@ set -xe
|
||||
|
||||
function install_ubuntu() {
|
||||
. /etc/os-release
|
||||
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
|
||||
echo "Ubuntu version ${VERSION_CODENAME} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -35,25 +35,24 @@ function install_ubuntu() {
|
||||
# The xpu-smi packages
|
||||
apt-get install -y flex bison xpu-smi
|
||||
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Compute and Media Runtimes
|
||||
# Compute and Media Runtimes
|
||||
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
|
||||
apt-get install -y \
|
||||
intel-opencl-icd intel-level-zero-gpu level-zero \
|
||||
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
|
||||
else # rolling driver
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
else # jammy
|
||||
apt-get install -y \
|
||||
intel-opencl-icd libze-intel-gpu1 libze1 \
|
||||
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
|
||||
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
|
||||
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
|
||||
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
fi
|
||||
# Development Packages
|
||||
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
|
||||
|
||||
# Install Intel Support Packages
|
||||
apt-get install -y ${XPU_PACKAGES}
|
||||
@ -66,7 +65,7 @@ function install_ubuntu() {
|
||||
function install_rhel() {
|
||||
. /etc/os-release
|
||||
if [[ "${ID}" == "rhel" ]]; then
|
||||
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
|
||||
echo "RHEL version ${VERSION_ID} not supported"
|
||||
exit
|
||||
fi
|
||||
@ -147,7 +146,7 @@ function install_sles() {
|
||||
XPU_DRIVER_VERSION=""
|
||||
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
|
||||
# Use GPU driver LTS releases
|
||||
XPU_DRIVER_VERSION="/lts/2350"
|
||||
XPU_DRIVER_VERSION="/lts/2523"
|
||||
fi
|
||||
|
||||
# Default use Intel® oneAPI Deep Learning Essentials 2025.1
|
||||
|
||||
@ -49,11 +49,7 @@ case ${DOCKER_TAG_PREFIX} in
|
||||
fi
|
||||
BASE_TARGET=rocm
|
||||
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
|
||||
;;
|
||||
*)
|
||||
|
||||
@ -87,11 +87,7 @@ case ${image} in
|
||||
MANY_LINUX_VERSION="2_28"
|
||||
DEVTOOLSET_VERSION="11"
|
||||
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
# add gfx950, gfx115x conditionally starting in ROCm 7.0
|
||||
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
|
||||
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
|
||||
fi
|
||||
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
|
||||
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
|
||||
;;
|
||||
manylinux2_28-builder:xpu)
|
||||
|
||||
@ -143,6 +143,15 @@ COPY ci_commit_pins/halide.txt halide.txt
|
||||
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
|
||||
RUN rm install_halide.sh common_utils.sh halide.txt
|
||||
|
||||
ARG PALLAS
|
||||
ARG CUDA_VERSION
|
||||
# Install JAX with CUDA support (for Pallas)
|
||||
COPY ./common/install_jax.sh install_jax.sh
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
|
||||
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
|
||||
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
|
||||
|
||||
ARG ONNX
|
||||
# Install ONNX dependencies
|
||||
COPY ./common/install_onnx.sh ./common/common_utils.sh ./
|
||||
|
||||
@ -8,9 +8,11 @@ from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
try:
|
||||
from typing import Any, Callable, Required, TypedDict # Python 3.11+
|
||||
from collections.abc import Callable # Python 3.11+
|
||||
from typing import Any, Required, TypedDict
|
||||
except ImportError:
|
||||
from typing import Any, Callable, TypedDict
|
||||
from collections.abc import Callable
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from typing_extensions import Required # Fallback for Python <3.11
|
||||
|
||||
|
||||
@ -824,6 +824,11 @@ test_inductor_halide() {
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_pallas() {
|
||||
python test/run_test.py --include inductor/test_pallas.py --verbose
|
||||
assert_git_not_dirty
|
||||
}
|
||||
|
||||
test_inductor_triton_cpu() {
|
||||
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
|
||||
assert_git_not_dirty
|
||||
@ -1724,6 +1729,8 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
|
||||
test_inductor_distributed
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
|
||||
test_inductor_halide
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
|
||||
test_inductor_pallas
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
|
||||
test_inductor_triton_cpu
|
||||
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
|
||||
|
||||
2
.github/ci_commit_pins/vision.txt
vendored
2
.github/ci_commit_pins/vision.txt
vendored
@ -1 +1 @@
|
||||
cfbc5c2f1c798991715a6b06bb3ce46478c4487c
|
||||
ccb801b88af136454798b945175c4c87e636ac33
|
||||
|
||||
9
.github/labeler.yml
vendored
9
.github/labeler.yml
vendored
@ -138,7 +138,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -148,7 +149,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/**/*cublas*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
@ -158,7 +160,8 @@
|
||||
- test/test_matmul_cuda.py
|
||||
- test/test_scaled_matmul_cuda.py
|
||||
- test/inductor/test_fp8.py
|
||||
- aten/src/ATen/native/cuda/Blas.cpp
|
||||
- aten/src/ATen/native/cuda/*Blas.cpp
|
||||
- aten/src/ATen/cuda/CUDA*Blas.*
|
||||
- torch/_inductor/kernel/mm.py
|
||||
- test/inductor/test_max_autotune.py
|
||||
- third_party/fbgemm
|
||||
|
||||
1
.github/nitpicks.yml
vendored
1
.github/nitpicks.yml
vendored
@ -10,3 +10,4 @@
|
||||
pathFilter:
|
||||
- 'torch/csrc/inductor/aoti_torch/c/*'
|
||||
- 'torch/csrc/inductor/aoti_torch/generated/*'
|
||||
- 'torch/csrc/stable/c/*'
|
||||
|
||||
6
.github/pytorch-probot.yml
vendored
6
.github/pytorch-probot.yml
vendored
@ -2,8 +2,8 @@ tracking_issue: 24422
|
||||
ciflow_tracking_issue: 64124
|
||||
ciflow_push_tags:
|
||||
- ciflow/b200
|
||||
- ciflow/b200-symm-mem
|
||||
- ciflow/b200-distributed
|
||||
- ciflow/b200-symm-mem
|
||||
- ciflow/binaries
|
||||
- ciflow/binaries_libtorch
|
||||
- ciflow/binaries_wheel
|
||||
@ -22,6 +22,8 @@ ciflow_push_tags:
|
||||
- ciflow/inductor-perf-test-nightly-xpu
|
||||
- ciflow/inductor-periodic
|
||||
- ciflow/inductor-rocm
|
||||
- ciflow/inductor-rocm-mi200
|
||||
- ciflow/inductor-rocm-mi300
|
||||
- ciflow/linux-aarch64
|
||||
- ciflow/mps
|
||||
- ciflow/nightly
|
||||
@ -33,11 +35,13 @@ ciflow_push_tags:
|
||||
- ciflow/quantization-periodic
|
||||
- ciflow/riscv64
|
||||
- ciflow/rocm
|
||||
- ciflow/rocm-mi200
|
||||
- ciflow/rocm-mi300
|
||||
- ciflow/rocm-mi355
|
||||
- ciflow/rocm-navi31
|
||||
- ciflow/s390
|
||||
- ciflow/slow
|
||||
- ciflow/slow-rocm-mi200
|
||||
- ciflow/torchbench
|
||||
- ciflow/triton_binaries
|
||||
- ciflow/trunk
|
||||
|
||||
3
.github/scripts/delete_old_branches.py
vendored
3
.github/scripts/delete_old_branches.py
vendored
@ -1,10 +1,11 @@
|
||||
# Delete old branches
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
from typing import Any
|
||||
|
||||
from github_utils import gh_fetch_json_dict, gh_graphql
|
||||
from gitutils import GitRepo
|
||||
|
||||
3
.github/scripts/filter_test_configs.py
vendored
3
.github/scripts/filter_test_configs.py
vendored
@ -8,10 +8,11 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from functools import cache
|
||||
from logging import info
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import yaml
|
||||
|
||||
3
.github/scripts/get_workflow_job_id.py
vendored
3
.github/scripts/get_workflow_job_id.py
vendored
@ -11,7 +11,8 @@ import sys
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
|
||||
3
.github/scripts/github_utils.py
vendored
3
.github/scripts/github_utils.py
vendored
@ -3,8 +3,9 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
from typing import Any, cast, Optional, Union
|
||||
from urllib.error import HTTPError
|
||||
from urllib.parse import quote
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
4
.github/scripts/gitutils.py
vendored
4
.github/scripts/gitutils.py
vendored
@ -4,10 +4,10 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Callable, Iterator
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||
from typing import Any, cast, Optional, TypeVar, Union
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
4
.github/scripts/trymerge.py
vendored
4
.github/scripts/trymerge.py
vendored
@ -17,12 +17,12 @@ import re
|
||||
import time
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Any, Callable, cast, NamedTuple, Optional
|
||||
from typing import Any, cast, NamedTuple, Optional
|
||||
from warnings import warn
|
||||
|
||||
import yaml
|
||||
|
||||
7
.github/workflows/docker-builds.yml
vendored
7
.github/workflows/docker-builds.yml
vendored
@ -56,6 +56,8 @@ jobs:
|
||||
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
|
||||
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
|
||||
pytorch-linux-jammy-py3.10-clang12,
|
||||
pytorch-linux-jammy-py3.11-clang12,
|
||||
pytorch-linux-jammy-py3.12-clang12,
|
||||
pytorch-linux-jammy-py3.13-clang12,
|
||||
pytorch-linux-jammy-py3.14-clang12,
|
||||
pytorch-linux-jammy-rocm-n-py3,
|
||||
@ -65,9 +67,10 @@ jobs:
|
||||
pytorch-linux-jammy-py3.10-gcc11,
|
||||
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3.12-halide,
|
||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas,
|
||||
pytorch-linux-jammy-xpu-n-1-py3,
|
||||
pytorch-linux-jammy-xpu-n-py3,
|
||||
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-noble-xpu-n-py3,
|
||||
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
|
||||
pytorch-linux-jammy-py3-clang18-asan,
|
||||
pytorch-linux-jammy-py3-clang12-onnx,
|
||||
pytorch-linux-jammy-linter,
|
||||
|
||||
@ -83,8 +83,8 @@ jobs:
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -117,7 +117,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
@ -137,7 +137,7 @@ jobs:
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: xpu-n-py3_10-inductor-benchmark-build
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
|
||||
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
|
||||
|
||||
@ -7,7 +7,7 @@ on:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
- ciflow/inductor-rocm-mi200/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
1
.github/workflows/inductor-rocm-mi300.yml
vendored
1
.github/workflows/inductor-rocm-mi300.yml
vendored
@ -7,6 +7,7 @@ on:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/inductor-rocm/*
|
||||
- ciflow/inductor-rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
|
||||
26
.github/workflows/inductor-unittest.yml
vendored
26
.github/workflows/inductor-unittest.yml
vendored
@ -81,6 +81,32 @@ jobs:
|
||||
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-build:
|
||||
name: inductor-pallas-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
build-environment: linux-jammy-cuda12.8-py3.12-gcc11
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
|
||||
cuda-arch-list: '8.9'
|
||||
runner: linux.8xlarge.memory
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
inductor-pallas-test:
|
||||
name: inductor-pallas-test
|
||||
uses: ./.github/workflows/_linux-test.yml
|
||||
needs: inductor-pallas-build
|
||||
with:
|
||||
build-environment: linux-jammy-py3.12-gcc11
|
||||
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
inductor-triton-cpu-build:
|
||||
name: inductor-triton-cpu-build
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
1
.github/workflows/periodic-rocm-mi200.yml
vendored
1
.github/workflows/periodic-rocm-mi200.yml
vendored
@ -11,7 +11,6 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi200/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
1
.github/workflows/periodic-rocm-mi300.yml
vendored
1
.github/workflows/periodic-rocm-mi300.yml
vendored
@ -11,6 +11,7 @@ on:
|
||||
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
|
||||
push:
|
||||
tags:
|
||||
- ciflow/periodic/*
|
||||
- ciflow/periodic-rocm-mi300/*
|
||||
branches:
|
||||
- release/*
|
||||
|
||||
8
.github/workflows/pull.yml
vendored
8
.github/workflows/pull.yml
vendored
@ -342,16 +342,16 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
# This should sync with the build in xpu.yml but xpu uses a larger runner
|
||||
# sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
|
||||
|
||||
@ -5,7 +5,7 @@ on:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
- ciflow/rocm-mi200/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: 29 8 * * * # about 1:29am PDT
|
||||
1
.github/workflows/rocm-mi300.yml
vendored
1
.github/workflows/rocm-mi300.yml
vendored
@ -6,6 +6,7 @@ on:
|
||||
- main
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/rocm/*
|
||||
- ciflow/rocm-mi300/*
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
|
||||
81
.github/workflows/slow-rocm-mi200.yml
vendored
Normal file
81
.github/workflows/slow-rocm-mi200.yml
vendored
Normal file
@ -0,0 +1,81 @@
|
||||
# This workflow is dedicated to host slow jobs that are run only periodically because
|
||||
# they are too slow to run in every commit. The list of slow tests can be found in
|
||||
# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
|
||||
name: slow-rocm-mi200
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- release/*
|
||||
tags:
|
||||
- ciflow/slow/*
|
||||
- ciflow/slow-rocm-mi200/*
|
||||
schedule:
|
||||
- cron: 0 */3 * * *
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
llm-td:
|
||||
if: github.repository_owner == 'pytorch'
|
||||
name: before-test
|
||||
uses: ./.github/workflows/llm_td_retrieval.yml
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
target-determination:
|
||||
name: before-test
|
||||
uses: ./.github/workflows/target_determination.yml
|
||||
needs: llm-td
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
get-label-type:
|
||||
name: get-label-type
|
||||
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
|
||||
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
|
||||
with:
|
||||
triggering_actor: ${{ github.triggering_actor }}
|
||||
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
|
||||
curr_branch: ${{ github.head_ref || github.ref_name }}
|
||||
curr_ref_type: ${{ github.ref_type }}
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
sync-tag: rocm-build
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
30
.github/workflows/slow.yml
vendored
30
.github/workflows/slow.yml
vendored
@ -105,36 +105,6 @@ jobs:
|
||||
test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-build:
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-rocm-py3_10-test:
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
name: linux-jammy-rocm-py3.10
|
||||
uses: ./.github/workflows/_rocm-test.yml
|
||||
needs:
|
||||
- linux-jammy-rocm-py3_10-build
|
||||
- target-determination
|
||||
with:
|
||||
build-environment: linux-jammy-rocm-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-py3_10-clang18-asan-build:
|
||||
name: linux-jammy-py3.10-clang18-asan
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
|
||||
5
.github/workflows/upload-test-stats.yml
vendored
5
.github/workflows/upload-test-stats.yml
vendored
@ -11,15 +11,16 @@ on:
|
||||
- inductor
|
||||
- unstable
|
||||
- slow
|
||||
- slow-rocm-mi200
|
||||
- unstable-periodic
|
||||
- inductor-periodic
|
||||
- rocm
|
||||
- rocm-mi200
|
||||
- rocm-mi300
|
||||
- rocm-mi355
|
||||
- inductor-micro-benchmark
|
||||
- inductor-micro-benchmark-x86
|
||||
- inductor-cu124
|
||||
- inductor-rocm
|
||||
- inductor-rocm-mi200
|
||||
- inductor-rocm-mi300
|
||||
- mac-mps
|
||||
- linux-aarch64
|
||||
|
||||
20
.github/workflows/xpu.yml
vendored
20
.github/workflows/xpu.yml
vendored
@ -47,15 +47,15 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-build:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-build:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_linux-build.yml
|
||||
needs: get-label-type
|
||||
with:
|
||||
sync-tag: linux-xpu-n-build
|
||||
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
|
||||
runner: linux.c7i.12xlarge
|
||||
test-matrix: |
|
||||
{ include: [
|
||||
@ -74,17 +74,17 @@ jobs:
|
||||
]}
|
||||
secrets: inherit
|
||||
|
||||
linux-jammy-xpu-n-py3_10-test:
|
||||
name: linux-jammy-xpu-n-py3.10
|
||||
linux-noble-xpu-n-py3_10-test:
|
||||
name: linux-noble-xpu-n-py3.10
|
||||
uses: ./.github/workflows/_xpu-test.yml
|
||||
needs: linux-jammy-xpu-n-py3_10-build
|
||||
needs: linux-noble-xpu-n-py3_10-build
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
with:
|
||||
build-environment: linux-jammy-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
build-environment: linux-noble-xpu-n-py3.10
|
||||
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
|
||||
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
|
||||
secrets: inherit
|
||||
|
||||
windows-xpu-n-1-build:
|
||||
|
||||
@ -143,7 +143,8 @@ init_command = [
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12"',
|
||||
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
|
||||
'numpy==2.3.4 ; python_version >= "3.14"',
|
||||
'expecttest==0.3.0',
|
||||
'pyrefly==0.36.2',
|
||||
'sympy==1.13.3',
|
||||
@ -1401,7 +1402,7 @@ init_command = [
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'usort==1.0.8.post1',
|
||||
'isort==6.0.1',
|
||||
'ruff==0.13.1', # sync with RUFF
|
||||
'ruff==0.14.4', # sync with RUFF
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
@ -1536,7 +1537,7 @@ init_command = [
|
||||
'python3',
|
||||
'tools/linter/adapters/pip_init.py',
|
||||
'--dry-run={{DRYRUN}}',
|
||||
'ruff==0.13.1', # sync with PYFMT
|
||||
'ruff==0.14.4', # sync with PYFMT
|
||||
]
|
||||
is_formatter = true
|
||||
|
||||
|
||||
@ -210,8 +210,12 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
|
||||
/test/inductor/test_flex_attention.py @drisspg
|
||||
/test/inductor/test_flex_decoding.py @drisspg
|
||||
|
||||
# Low Precision GEMMs
|
||||
# Low Precision & Grouped GEMMs
|
||||
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
|
||||
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
|
||||
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
|
||||
|
||||
@ -174,6 +174,12 @@ class TORCH_API Context {
|
||||
static long versionCuDNN() {
|
||||
return detail::getCUDAHooks().versionCuDNN();
|
||||
}
|
||||
static long versionRuntimeCuDNN() {
|
||||
return detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
}
|
||||
static long versionCuDNNFrontend() {
|
||||
return detail::getCUDAHooks().versionCuDNNFrontend();
|
||||
}
|
||||
static bool hasCuSOLVER() {
|
||||
return detail::getCUDAHooks().hasCuSOLVER();
|
||||
}
|
||||
|
||||
@ -94,6 +94,11 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
|
||||
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
|
||||
}
|
||||
|
||||
TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
|
||||
c10::DeviceIndex device_index) {
|
||||
const auto device_type = getAccelerator(true).value();
|
||||
return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
|
||||
}
|
||||
} // namespace at::accelerator
|
||||
|
||||
namespace at {
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
#include <c10/util/complex.h>
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
|
||||
#ifdef __CUDACC__
|
||||
#include <cuda.h> // For CUDA_VERSION
|
||||
@ -61,12 +62,9 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
case enum_type: { \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
|
||||
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
|
||||
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
|
||||
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_CASE(enum_type, ...) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
|
||||
@ -95,14 +93,6 @@ TORCH_API void record_kernel_function_dtype(std::string name);
|
||||
return __VA_ARGS__(); \
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// The AT_DISPATCH_* family of macros provides the ability to
|
||||
// conveniently generate specializations of a kernel over all of the
|
||||
// dtypes we care about in PyTorch. We call it "dispatch" because
|
||||
@ -190,27 +180,13 @@ inline at::ScalarType scalar_type(at::ScalarType s) {
|
||||
// but we're just being safe (and it doesn't hurt.) Note we must
|
||||
// use it to shut up warnings about unused store.
|
||||
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
[&] { \
|
||||
const auto& the_type = TYPE; \
|
||||
constexpr const char* at_dispatch_name = NAME; \
|
||||
/* don't use TYPE again in case it is an expensive or side-effect op */ \
|
||||
at::ScalarType _st = ::detail::scalar_type(the_type); \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
|
||||
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
|
||||
switch (_st) { \
|
||||
__VA_ARGS__ \
|
||||
default: \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED( \
|
||||
false, \
|
||||
'"', \
|
||||
at_dispatch_name, \
|
||||
"\" not implemented for '", \
|
||||
toString(_st), \
|
||||
"'"); \
|
||||
} \
|
||||
C10_DIAGNOSTIC_POP() \
|
||||
}()
|
||||
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
|
||||
THO_DISPATCH_SWITCH_TMPL( \
|
||||
RECORD_KERNEL_FUNCTION_DTYPE, \
|
||||
TORCH_CHECK_NOT_IMPLEMENTED, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
|
||||
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
|
||||
|
||||
@ -1,3 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
|
||||
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
// This is a new implementation of the AT_DISPATCH macro family from
|
||||
@ -74,41 +79,19 @@
|
||||
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
|
||||
// relied on GPT4 to help me get it right.
|
||||
|
||||
// Public API macros
|
||||
|
||||
// See documentation above
|
||||
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
|
||||
|
||||
// This macro lets you pass an arbitrary expression that may contain internal
|
||||
// commas to another macro without having the commas causing the expression
|
||||
// to be interpreted as being multiple arguments
|
||||
#define AT_WRAP(...) __VA_ARGS__
|
||||
|
||||
#define AT_FLOAT8_TYPES \
|
||||
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
|
||||
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
|
||||
|
||||
#define AT_INTEGRAL_TYPES \
|
||||
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
|
||||
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
|
||||
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
|
||||
#define AT_INTEGRAL_TYPES_V2 \
|
||||
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
|
||||
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
|
||||
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
|
||||
// NB: not *actually* all types
|
||||
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
|
||||
#define AT_ALL_TYPES_AND_COMPLEX \
|
||||
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
|
||||
|
||||
// Helper macros
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
AT_DISPATCH_SWITCH, \
|
||||
AT_DISPATCH_CASE, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_WRAP(BODY), \
|
||||
__VA_ARGS__)
|
||||
|
||||
// Unused helper macros, kept for BC:
|
||||
#define AT_AP_VAR(N, T, ...) \
|
||||
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
|
||||
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
|
||||
#define AT_CONCAT_AUX(a, b) a##b
|
||||
#define AT_EXPAND(X) X
|
||||
|
||||
// Ensure we never have too many scalar types for the expansion here to
|
||||
// support. To bump this, you must regenerate the macros below.
|
||||
@ -119,12 +102,6 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
|
||||
|
||||
num_args = 60
|
||||
|
||||
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
|
||||
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
|
||||
|
||||
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
|
||||
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
|
||||
|
||||
for i in range(1, num_args+1):
|
||||
args = ', '.join(f'_{i}' for i in range(1, i+1))
|
||||
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
|
||||
@ -135,8 +112,6 @@ for i in range(1, num_args+1):
|
||||
// Begin generated code
|
||||
// clang-format off
|
||||
|
||||
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
|
||||
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
|
||||
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
|
||||
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
|
||||
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
|
||||
|
||||
@ -226,8 +226,8 @@ template <
|
||||
typename B = HostBlock<S>>
|
||||
struct CachingHostAllocatorImpl {
|
||||
virtual ~CachingHostAllocatorImpl() {
|
||||
active_ = false;
|
||||
if (pinned_use_background_threads()) {
|
||||
if (active_) {
|
||||
active_ = false;
|
||||
getBackgroundThreadPool()->waitWorkComplete();
|
||||
}
|
||||
}
|
||||
@ -260,6 +260,7 @@ struct CachingHostAllocatorImpl {
|
||||
if (pinned_use_background_threads()) {
|
||||
// Launch the background thread and process events in a loop.
|
||||
static bool background_thread_flag [[maybe_unused]] = [this] {
|
||||
active_ = true;
|
||||
getBackgroundThreadPool()->run([&]() {
|
||||
while (active_) {
|
||||
process_events();
|
||||
@ -683,9 +684,9 @@ struct CachingHostAllocatorImpl {
|
||||
alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
|
||||
std::deque<std::pair<E, B*>> events_; // event queue paired with block
|
||||
|
||||
// Indicates whether the object is active.
|
||||
// Indicates whether the event-processing thread pool is active.
|
||||
// Set to false in the destructor to signal background threads to stop.
|
||||
std::atomic<bool> active_{true};
|
||||
std::atomic<bool> active_{false};
|
||||
protected:
|
||||
alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
|
||||
};
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#include <ATen/cudnn/cudnn-wrapper.h>
|
||||
#include <cudnn_frontend.h>
|
||||
#endif
|
||||
|
||||
#if AT_MAGMA_ENABLED()
|
||||
@ -351,6 +352,26 @@ long CUDAHooks::versionCuDNN() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionRuntimeCuDNN() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
#ifndef USE_STATIC_CUDNN
|
||||
return cudnnGetVersion();
|
||||
#else
|
||||
return CUDNN_VERSION;
|
||||
#endif
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionCuDNNFrontend() const {
|
||||
#if AT_CUDNN_ENABLED()
|
||||
return CUDNN_FRONTEND_VERSION;
|
||||
#else
|
||||
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
|
||||
#endif
|
||||
}
|
||||
|
||||
long CUDAHooks::versionMIOpen() const {
|
||||
#if AT_ROCM_ENABLED()
|
||||
return MIOPEN_VERSION_MAJOR * 10000 +
|
||||
|
||||
@ -49,6 +49,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
bool hasCUDART() const override;
|
||||
long versionCUDART() const override;
|
||||
long versionCuDNN() const override;
|
||||
long versionRuntimeCuDNN() const override;
|
||||
long versionCuDNNFrontend() const override;
|
||||
long versionMIOpen() const override;
|
||||
std::string showConfig() const override;
|
||||
double batchnormMinEpsilonCuDNN() const override;
|
||||
|
||||
@ -174,6 +174,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionRuntimeCuDNN() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionCuDNNFrontend() const {
|
||||
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual long versionMIOpen() const {
|
||||
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
@ -157,6 +157,8 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
|
||||
DispatchKey::Negative,
|
||||
DispatchKey::Conjugate,
|
||||
DispatchKey::XLA,
|
||||
DispatchKey::XPU,
|
||||
DispatchKey::HPU,
|
||||
DispatchKey::CUDA,
|
||||
DispatchKey::CPU,
|
||||
DispatchKey::PrivateUse1,
|
||||
|
||||
@ -409,7 +409,7 @@ struct ConvParams {
|
||||
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
|
||||
return false;
|
||||
}
|
||||
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
|
||||
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
// broken on cuDNN 9.8 - 9.14
|
||||
if (cudnn_version >= 90800 && cudnn_version < 91500) {
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
|
||||
@ -453,7 +453,7 @@ struct ConvParams {
|
||||
}
|
||||
// native kernel doesn't support 64-bit non-splittable case
|
||||
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
|
||||
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
|
||||
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
|
||||
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
|
||||
if (cudnn_version < 0 || cudnn_version > 91000) {
|
||||
|
||||
@ -884,6 +884,69 @@ struct type_specialized_kernel_launcher {
|
||||
}
|
||||
};
|
||||
|
||||
template <int arg_index>
|
||||
struct type_specialized_broadcast_kernel_launcher {
|
||||
template <
|
||||
typename func_t,
|
||||
typename array_t,
|
||||
typename dtypes_t,
|
||||
typename calc_t>
|
||||
static void apply(
|
||||
int64_t numel,
|
||||
func_t f,
|
||||
array_t data,
|
||||
dtypes_t dtypes,
|
||||
calc_t offset_calc) {
|
||||
using traits = function_traits<func_t>;
|
||||
using ret_t = typename traits::result_type;
|
||||
using arg0_t = typename traits::template arg<0>::type;
|
||||
using arg1_t = typename traits::template arg<1>::type;
|
||||
if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
|
||||
dtypes[1] == rt_binary_specializations[arg_index][1] &&
|
||||
dtypes[2] == rt_binary_specializations[arg_index][2]) {
|
||||
using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
|
||||
using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
|
||||
using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
auto offsets0 = offset_calc.get(idx);
|
||||
auto offsets1 = offset_calc.get(idx + grp_sz);
|
||||
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
|
||||
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
|
||||
void* out0 = data[0] + offsets0[0];
|
||||
void* out1 = data[0] + offsets1[0];
|
||||
void* out2 = data[0] + offsets2[0];
|
||||
void* out3 = data[0] + offsets3[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
|
||||
ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
|
||||
auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
|
||||
ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
|
||||
auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
|
||||
auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
|
||||
ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
|
||||
auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
|
||||
auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
|
||||
ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
|
||||
*(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
|
||||
*(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
|
||||
*(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
|
||||
*(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
|
||||
} else {
|
||||
auto offsets = offset_calc.get(idx);
|
||||
void* out = data[0] + offsets[0];
|
||||
auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
|
||||
auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
|
||||
ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
|
||||
*(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
#endif
|
||||
|
||||
@ -1002,6 +1065,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
|
||||
}
|
||||
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
|
||||
#ifdef USE_ROCM
|
||||
if (check_binary_rt_types_for_specialization(iter)) {
|
||||
// constexpr to reduce the amount of kernels generated for
|
||||
// broadcast elementwise with mexed dtypes and limit which functors are actually
|
||||
// applied to the load and store at compile time.
|
||||
using func_tuple = typename traits::ArgsTuple;
|
||||
if constexpr (
|
||||
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
|
||||
check_binary_functor_types_for_specialization<
|
||||
func_tuple,
|
||||
float,
|
||||
float,
|
||||
traits::arity,
|
||||
/*arg_num=*/0>::check()) {
|
||||
memory::detail::static_unroll<
|
||||
type_specialized_broadcast_kernel_launcher,
|
||||
rt_binary_specializations.size()>::with_args(
|
||||
numel,
|
||||
f,
|
||||
data,
|
||||
dtypes,
|
||||
offset_calc
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int grp_sz = 128;
|
||||
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
|
||||
if (unrl) {
|
||||
|
||||
@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
|
||||
// supported in conv.
|
||||
mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
|
||||
if (groups > 1 && weight_zero_points.numel() > 1)
|
||||
mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
|
||||
mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
|
||||
dnnl::primitive_attr pattr;
|
||||
|
||||
bool src_need_zp = (act_zero_point != 0);
|
||||
|
||||
@ -141,6 +141,9 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
|
||||
};
|
||||
|
||||
MPSStream* stream = at::mps::getCurrentMPSStream();
|
||||
if (result.numel() == 0) {
|
||||
return result;
|
||||
}
|
||||
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
|
||||
|
||||
@autoreleasepool {
|
||||
|
||||
@ -2803,7 +2803,7 @@
|
||||
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
dispatch:
|
||||
CPU, CUDA, MPS: floor_divide_out
|
||||
CPU, CUDA, MPS, MTIA: floor_divide_out
|
||||
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
|
||||
|
||||
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
|
||||
@ -4383,7 +4383,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: mv
|
||||
SparseCPU, SparseCUDA: mv_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: mv_sparse
|
||||
|
||||
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
|
||||
@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
|
||||
const auto s_k = params.key.sym_size(2);
|
||||
const auto d_qk = params.query.sym_size(3);
|
||||
const auto d_v = params.value.sym_size(3);
|
||||
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
|
||||
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
if (cudnn_version < 8903) {
|
||||
if (debug) {
|
||||
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
|
||||
@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
|
||||
return false;
|
||||
#endif
|
||||
#if defined(CUDNN_VERSION)
|
||||
static auto cudnn_version = cudnnGetVersion();
|
||||
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
|
||||
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
|
||||
if (debug) {
|
||||
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13).");
|
||||
|
||||
@ -52,19 +52,18 @@ def test_sparse_coo_and_csr(m, n, k, nnz, test_count):
|
||||
start.record()
|
||||
coo.matmul(mat)
|
||||
stop.record()
|
||||
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
coo_mean_time = sum(times) / len(times)
|
||||
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
times = []
|
||||
for _ in range(test_count):
|
||||
start.record()
|
||||
csr.matmul(mat)
|
||||
stop.record()
|
||||
times.append(start.elapsed_time(stop))
|
||||
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
csr_mean_time = sum(times) / len(times)
|
||||
|
||||
return coo_mean_time, csr_mean_time
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/SafePyObject.h>
|
||||
#include <c10/macros/Export.h>
|
||||
#include <optional>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
@ -15,7 +17,8 @@ struct C10_API AutogradState {
|
||||
bool inference_mode,
|
||||
bool fw_grad_mode,
|
||||
bool multithreading_enabled)
|
||||
: grad_mode_(grad_mode),
|
||||
: graph_exec_group_(std::nullopt),
|
||||
grad_mode_(grad_mode),
|
||||
inference_mode_(inference_mode),
|
||||
fw_grad_mode_(fw_grad_mode),
|
||||
multithreading_enabled_(multithreading_enabled),
|
||||
@ -41,6 +44,10 @@ struct C10_API AutogradState {
|
||||
view_replay_enabled_ = view_replay_enabled;
|
||||
}
|
||||
|
||||
void set_graph_exec_group(std::optional<SafePyObject> group) {
|
||||
graph_exec_group_ = std::move(group);
|
||||
}
|
||||
|
||||
bool get_grad_mode() const {
|
||||
return grad_mode_;
|
||||
}
|
||||
@ -61,7 +68,12 @@ struct C10_API AutogradState {
|
||||
return view_replay_enabled_;
|
||||
}
|
||||
|
||||
const std::optional<SafePyObject>& get_graph_exec_group() const {
|
||||
return graph_exec_group_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::optional<SafePyObject> graph_exec_group_;
|
||||
bool grad_mode_ : 1;
|
||||
bool inference_mode_ : 1;
|
||||
bool fw_grad_mode_ : 1;
|
||||
|
||||
@ -96,6 +96,10 @@ struct C10_API DeviceAllocator : public c10::Allocator {
|
||||
|
||||
// Resets peak memory usage statistics for the specified device
|
||||
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
|
||||
|
||||
// Return the free memory size and total memory size in bytes for the
|
||||
// specified device.
|
||||
virtual std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) = 0;
|
||||
};
|
||||
|
||||
// This function is used to get the DeviceAllocator for a specific device type
|
||||
|
||||
@ -345,6 +345,13 @@ class CUDAAllocator : public DeviceAllocator {
|
||||
c10::DeviceIndex device,
|
||||
std::shared_ptr<AllocatorState> pps) = 0;
|
||||
virtual std::string name() = 0;
|
||||
std::pair<size_t, size_t> getMemoryInfo(c10::DeviceIndex device) override {
|
||||
c10::DeviceGuard device_guard({at::kCUDA, device});
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
C10_CUDA_CHECK(cudaMemGetInfo(&free, &total));
|
||||
return {free, total};
|
||||
}
|
||||
};
|
||||
|
||||
// Allocator object, statically initialized
|
||||
|
||||
@ -66,6 +66,15 @@ def define_targets(rules):
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/nofatal_test",
|
||||
srcs = ["util/nofatal_test.cpp"],
|
||||
deps = [
|
||||
"//c10/util:base",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
rules.cc_test(
|
||||
name = "util/ssize_test",
|
||||
srcs = ["util/ssize_test.cpp"],
|
||||
|
||||
53
c10/test/util/nofatal_test.cpp
Normal file
53
c10/test/util/nofatal_test.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
inline void expectThrowsEq(T&& fn, const char* expected_msg) {
|
||||
try {
|
||||
std::forward<T>(fn)();
|
||||
} catch (const c10::Error& e) {
|
||||
EXPECT_TRUE(
|
||||
std::string(e.what_without_backtrace()).find(expected_msg) !=
|
||||
std::string::npos);
|
||||
return;
|
||||
}
|
||||
ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg
|
||||
<< "\" but didn't throw";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(NofatalTest, TorchCheckComparisons) {
|
||||
// quick make sure that no-op works as expected
|
||||
TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1;
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; },
|
||||
"Check failed: 1 == 2 (1 vs. 2). i am a silly message 1");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2).");
|
||||
expectThrowsEq(
|
||||
[]() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3).");
|
||||
expectThrowsEq(
|
||||
[]() {
|
||||
void* p = nullptr;
|
||||
TORCH_CHECK_NOTNULL(p);
|
||||
},
|
||||
"Check failed: 'p' must be non NULL.");
|
||||
|
||||
#if GTEST_HAS_DEATH_TEST
|
||||
#ifndef NDEBUG
|
||||
// if dbg build, DCHECK should result in deth
|
||||
EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed");
|
||||
#else
|
||||
TORCH_DCHECK_EQ(1, 2); // no-op
|
||||
#endif
|
||||
#endif // GTEST_HAS_DEATH_TEST
|
||||
}
|
||||
@ -702,6 +702,98 @@ namespace c10::detail {
|
||||
#define TORCH_CHECK_ARG(cond, argN, ...) \
|
||||
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
|
||||
|
||||
#ifndef FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \
|
||||
.stream()
|
||||
#else
|
||||
#define FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef NON_FATAL_IF
|
||||
#ifdef C10_USE_GLOG
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger( \
|
||||
__FILE__, __LINE__, ::google::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#else
|
||||
#define NON_FATAL_IF(condition) \
|
||||
condition ? (void)0 \
|
||||
: ::c10::LoggerVoidify() & \
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \
|
||||
.stream()
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Binary comparison check macros
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
NON_FATAL_IF(((val1)op(val2))) \
|
||||
<< "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
|
||||
<< (val2) << "). "
|
||||
|
||||
#define TORCH_DCHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << "). "
|
||||
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
// Debug versions of TORCH_CHECK_OP macros
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// Optimized versions - generate no code
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_DCHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Null pointer check macro
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false)
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true)
|
||||
#else // !NDEBUG
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Deprecated macros
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
@ -291,6 +291,32 @@ namespace c10 {
|
||||
using fLB::FLAGS_logtostderr;
|
||||
using fLI::FLAGS_minloglevel;
|
||||
using fLI::FLAGS_v;
|
||||
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {}
|
||||
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
if (severity_ == ::google::GLOG_FATAL) {
|
||||
DealWithFatal();
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
LOG(FATAL) << stream_.str();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
C10_DEFINE_int(
|
||||
@ -412,17 +438,16 @@ void ShowLogInfoToStderr() {
|
||||
FLAGS_caffe2_log_level = GLOG_INFO;
|
||||
}
|
||||
|
||||
MessageLogger::MessageLogger(const char* file, int line, int severity)
|
||||
: severity_(severity) {
|
||||
MessageLogger::MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal)
|
||||
: severity_(severity), exit_on_fatal_(exit_on_fatal) {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
}
|
||||
#ifdef ANDROID
|
||||
tag_ = "native";
|
||||
#else // !ANDROID
|
||||
tag_ = "";
|
||||
#endif // ANDROID
|
||||
|
||||
time_t rawtime = 0;
|
||||
time(&rawtime);
|
||||
@ -458,7 +483,7 @@ MessageLogger::MessageLogger(const char* file, int line, int severity)
|
||||
}
|
||||
|
||||
// Output the contents of the stream to the proper channel on destruction.
|
||||
MessageLogger::~MessageLogger() {
|
||||
MessageLogger::~MessageLogger() noexcept(false) {
|
||||
if (severity_ < FLAGS_caffe2_log_level) {
|
||||
// Nothing needs to be logged.
|
||||
return;
|
||||
@ -498,6 +523,18 @@ MessageLogger::~MessageLogger() {
|
||||
}
|
||||
}
|
||||
|
||||
std::stringstream& MessageLogger::stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
void MessageLogger::DealWithFatal() {
|
||||
if (exit_on_fatal_) {
|
||||
abort();
|
||||
} else {
|
||||
throw c10::Error(stream_.str(), nullptr, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // !C10_USE_GLOG
|
||||
|
||||
74
c10/util/logging_common.h
Normal file
74
c10/util/logging_common.h
Normal file
@ -0,0 +1,74 @@
|
||||
#ifndef C10_UTIL_LOGGING_COMMON_H_
|
||||
#define C10_UTIL_LOGGING_COMMON_H_
|
||||
|
||||
#include <c10/macros/Export.h>
|
||||
#include <sstream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
// MessageLogger that throws exceptions instead of aborting (glog version)
|
||||
// or logs and may abort (non-glog version).
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(
|
||||
const char* file,
|
||||
int line,
|
||||
int severity,
|
||||
bool exit_on_fatal = true);
|
||||
~MessageLogger() noexcept(false);
|
||||
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream();
|
||||
|
||||
private:
|
||||
// When there is a fatal log, and fatal == true, we abort
|
||||
// otherwise, we throw.
|
||||
void DealWithFatal();
|
||||
|
||||
#if defined(ANDROID) && !defined(C10_USE_GLOG)
|
||||
const char* tag_{"native"};
|
||||
#endif
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
bool exit_on_fatal_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Forward declarations for CheckNotNull functions
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal = true);
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal = true);
|
||||
|
||||
} // namespace c10
|
||||
|
||||
#endif // C10_UTIL_LOGGING_COMMON_H_
|
||||
@ -47,57 +47,53 @@ INSTANTIATE_FOR_CONTAINER(set)
|
||||
|
||||
#endif
|
||||
|
||||
#include <c10/util/logging_common.h>
|
||||
#include <glog/logging.h>
|
||||
|
||||
// Additional macros on top of glog
|
||||
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
|
||||
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
|
||||
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
|
||||
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
|
||||
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
|
||||
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
|
||||
namespace c10 {
|
||||
|
||||
#ifndef NDEBUG
|
||||
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_EQ(val1, val2)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_NE(val1, val2)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LE(val1, val2)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_LT(val1, val2)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GE(val1, val2)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
DCHECK_GT(val1, val2)
|
||||
#endif // NDEBUG
|
||||
[[noreturn]] void ThrowEnforceNotMet(
|
||||
const char* file,
|
||||
const int line,
|
||||
const char* condition,
|
||||
const std::string& msg,
|
||||
const void* caller);
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
if (t == nullptr) {
|
||||
MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
DCHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
template <typename T>
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
||||
// Log with source location information override (to be used in generic
|
||||
// warning/error handlers implemented as functions, not macros)
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Flags.h>
|
||||
#include <c10/util/logging_common.h>
|
||||
|
||||
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
|
||||
|
||||
@ -24,61 +25,40 @@ const int GLOG_ERROR = 2;
|
||||
const int GLOG_WARNING = 1;
|
||||
const int GLOG_INFO = 0;
|
||||
|
||||
class C10_API MessageLogger {
|
||||
public:
|
||||
MessageLogger(const char* file, int line, int severity);
|
||||
~MessageLogger();
|
||||
// Return the stream associated with the logger object.
|
||||
std::stringstream& stream() {
|
||||
return stream_;
|
||||
}
|
||||
|
||||
private:
|
||||
// When there is a fatal log, we simply abort.
|
||||
void DealWithFatal() {
|
||||
abort();
|
||||
}
|
||||
|
||||
const char* tag_;
|
||||
std::stringstream stream_;
|
||||
int severity_;
|
||||
};
|
||||
|
||||
// This class is used to explicitly ignore values in the conditional
|
||||
// logging macros. This avoids compiler warnings like "value computed
|
||||
// is not used" and "statement has no effect".
|
||||
class C10_API LoggerVoidify {
|
||||
public:
|
||||
LoggerVoidify() = default;
|
||||
// This has to be an operator with a precedence lower than << but
|
||||
// higher than ?:
|
||||
void operator&(const std::ostream& s [[maybe_unused]]) {}
|
||||
};
|
||||
|
||||
// Log a message and terminate.
|
||||
template <class T>
|
||||
void LogMessageFatal(const char* file, int line, const T& message) {
|
||||
MessageLogger(file, line, GLOG_FATAL).stream() << message;
|
||||
}
|
||||
|
||||
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
|
||||
// pointers and smart pointers.
|
||||
template <typename T>
|
||||
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
|
||||
T& CheckNotNullCommon(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
if (t == nullptr) {
|
||||
LogMessageFatal(file, line, std::string(names));
|
||||
MessageLogger(file, line, GLOG_FATAL, fatal).stream()
|
||||
<< "Check failed: '" << names << "' must be non NULL. ";
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
T* CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T* t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
|
||||
return CheckNotNullCommon(file, line, names, t);
|
||||
T& CheckNotNull(
|
||||
const char* file,
|
||||
int line,
|
||||
const char* names,
|
||||
T& t,
|
||||
bool fatal) {
|
||||
return CheckNotNullCommon(file, line, names, t, fatal);
|
||||
}
|
||||
} // namespace c10
|
||||
|
||||
@ -136,65 +116,6 @@ static_assert(
|
||||
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
|
||||
#endif // NDEBUG
|
||||
|
||||
#define TORCH_CHECK_OP(val1, val2, op) \
|
||||
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
|
||||
<< (val1) << " vs. " << (val2) << ") "
|
||||
|
||||
// TORCH_CHECK_OP macro definitions
|
||||
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only versions of TORCH_CHECK_OP macros.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
|
||||
#else // !NDEBUG
|
||||
// These versions generate no code in optimized mode.
|
||||
#define TORCH_DCHECK_EQ(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, ==)
|
||||
#define TORCH_DCHECK_NE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, !=)
|
||||
#define TORCH_DCHECK_LE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <=)
|
||||
#define TORCH_DCHECK_LT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, <)
|
||||
#define TORCH_DCHECK_GE(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >=)
|
||||
#define TORCH_DCHECK_GT(val1, val2) \
|
||||
while (false) \
|
||||
TORCH_CHECK_OP(val1, val2, >)
|
||||
#endif // NDEBUG
|
||||
|
||||
// Check that a pointer is not null.
|
||||
#define TORCH_CHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
|
||||
#ifndef NDEBUG
|
||||
// Debug only version of TORCH_CHECK_NOTNULL
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
::c10::CheckNotNull( \
|
||||
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
|
||||
#else // !NDEBUG
|
||||
// Optimized version - generates no code.
|
||||
#define TORCH_DCHECK_NOTNULL(val) \
|
||||
while (false) \
|
||||
TORCH_CHECK_NOTNULL(val)
|
||||
#endif // NDEBUG
|
||||
|
||||
// ---------------------- Support for std objects --------------------------
|
||||
// These are adapted from glog to support a limited set of logging capability
|
||||
// for STL objects.
|
||||
|
||||
@ -1240,6 +1240,10 @@ class XPUAllocator : public DeviceAllocator {
|
||||
c10::xpu::get_raw_device(dev_to_access));
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getMemoryInfo(DeviceIndex device) override {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented yet.");
|
||||
}
|
||||
|
||||
double getMemoryFraction(DeviceIndex device) {
|
||||
assertValidDevice(device);
|
||||
return device_allocators[device]->getMemoryFraction();
|
||||
|
||||
@ -1941,6 +1941,7 @@ if(BUILD_TEST)
|
||||
foreach(test_src ${Caffe2_XPU_TEST_SRCS})
|
||||
get_filename_component(test_name ${test_src} NAME_WE)
|
||||
add_executable(${test_name} "${test_src}")
|
||||
torch_compile_options(${test_name})
|
||||
target_link_libraries(${test_name} torch_library gtest_main)
|
||||
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})
|
||||
|
||||
@ -40,6 +40,7 @@
|
||||
:nosignatures:
|
||||
|
||||
empty_cache
|
||||
get_memory_info
|
||||
max_memory_allocated
|
||||
max_memory_reserved
|
||||
memory_allocated
|
||||
|
||||
@ -172,9 +172,9 @@ ignore = [
|
||||
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
|
||||
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
|
||||
"SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
|
||||
"SIM110",
|
||||
"SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all.
|
||||
"SIM114", # Combine `if` branches using logical `or` operator
|
||||
"SIM115",
|
||||
"SIM115", # Checks for cases where files are opened without using a context manager.
|
||||
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
|
||||
"SIM117",
|
||||
"SIM118",
|
||||
@ -184,7 +184,6 @@ ignore = [
|
||||
"TC006",
|
||||
# TODO: Remove Python-3.10 specific suppressions
|
||||
"B905",
|
||||
"UP035",
|
||||
]
|
||||
select = [
|
||||
"B",
|
||||
|
||||
3
setup.py
3
setup.py
@ -1646,8 +1646,7 @@ def main() -> None:
|
||||
mirror_files_into_torchgen()
|
||||
if RUN_BUILD_DEPS:
|
||||
build_deps()
|
||||
|
||||
mirror_inductor_external_kernels()
|
||||
mirror_inductor_external_kernels()
|
||||
|
||||
(
|
||||
ext_modules,
|
||||
|
||||
@ -208,7 +208,7 @@ class _BaseDataSparsiferTestCase(TestCase):
|
||||
assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)
|
||||
|
||||
state1 = state_dict1["state"]
|
||||
for name in state1.keys():
|
||||
for name in state1:
|
||||
# compare mask
|
||||
assert name in sparsifier2.state
|
||||
assert "mask" in sparsifier2.state[name]
|
||||
|
||||
@ -119,7 +119,7 @@ class TestBaseSparsifier(TestCase):
|
||||
for idx in range(len(sparsifier0.groups)):
|
||||
mg0 = sparsifier0.groups[idx]
|
||||
mg1 = sparsifier1.groups[idx]
|
||||
for key in mg0.keys():
|
||||
for key in mg0:
|
||||
assert key in mg1
|
||||
if key == "module":
|
||||
# We cannot compare modules as they are different
|
||||
|
||||
@ -10,6 +10,8 @@ set(AOTI_ABI_CHECK_TEST_SRCS
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
|
||||
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp
|
||||
|
||||
82
test/cpp/aoti_abi_check/test_dispatch.cpp
Normal file
82
test/cpp/aoti_abi_check/test_dispatch.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <torch/headeronly/core/Dispatch.h>
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
|
||||
// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For
|
||||
// testing, we do nothing:
|
||||
#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
|
||||
|
||||
#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \
|
||||
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
|
||||
MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
|
||||
|
||||
#define MY_DISPATCH_CASE(...) \
|
||||
THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__)
|
||||
|
||||
// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch
|
||||
// statement. For testing, we just avoid unused variable warning:
|
||||
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
|
||||
(void)DISPATCHNAME
|
||||
|
||||
// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For
|
||||
// testing, we count case mismatches:
|
||||
#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++
|
||||
|
||||
#define MY_DISPATCH_SWITCH(...) \
|
||||
THO_DISPATCH_SWITCH_TMPL( \
|
||||
MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)
|
||||
|
||||
// MY_CASE_FUNCTION is called in a case block. For testing, we count
|
||||
// case matches and ensure that scalar_t/index_t type is defined:
|
||||
#define MY_CASE_FUNCTION \
|
||||
[&] { \
|
||||
count++; \
|
||||
scalar_t tmp; \
|
||||
(void)tmp; \
|
||||
}
|
||||
#define MY_INDEX_CASE_FUNCTION \
|
||||
[&] { \
|
||||
count++; \
|
||||
index_t tmp; \
|
||||
(void)tmp; \
|
||||
}
|
||||
|
||||
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
|
||||
|
||||
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
|
||||
THO_DISPATCH_V2_TMPL( \
|
||||
MY_DISPATCH_SWITCH, \
|
||||
MY_DISPATCH_CASE, \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_WRAP(BODY), \
|
||||
__VA_ARGS__)
|
||||
|
||||
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestDispatchV2, NAME) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t total_count = 0; \
|
||||
int8_t count = 0; \
|
||||
int8_t default_count = 0; \
|
||||
for (ScalarType t : \
|
||||
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
|
||||
total_count++; \
|
||||
MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \
|
||||
} \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
EXPECT_EQ(default_count + count, total_count); \
|
||||
}
|
||||
|
||||
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
|
||||
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
|
||||
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
|
||||
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
|
||||
|
||||
#undef DEFINE_ITEM
|
||||
45
test/cpp/aoti_abi_check/test_dispatch_v2.cpp
Normal file
45
test/cpp/aoti_abi_check/test_dispatch_v2.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <torch/headeronly/core/Dispatch_v2.h>
|
||||
#include <torch/headeronly/util/Exception.h>
|
||||
|
||||
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
|
||||
|
||||
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
|
||||
TEST(TestThoDispatchV2, NAME) { \
|
||||
using torch::headeronly::ScalarType; \
|
||||
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
|
||||
int8_t total_count = 0; \
|
||||
int8_t count = 0; \
|
||||
int8_t default_count = 0; \
|
||||
for (ScalarType t : \
|
||||
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
|
||||
total_count++; \
|
||||
try { \
|
||||
THO_DISPATCH_V2( \
|
||||
t, \
|
||||
"test_tho_dispatch_v2", \
|
||||
[&] { \
|
||||
count++; \
|
||||
scalar_t tmp; \
|
||||
(void)tmp; \
|
||||
}, \
|
||||
__VA_ARGS__); \
|
||||
} catch (...) { \
|
||||
default_count++; /* counts mismatches */ \
|
||||
} \
|
||||
} \
|
||||
EXPECT_EQ(count, EXPECTEDCOUNT); \
|
||||
EXPECT_EQ(default_count + count, total_count); \
|
||||
}
|
||||
|
||||
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
|
||||
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
|
||||
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
|
||||
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
|
||||
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
|
||||
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
|
||||
|
||||
#undef DEFINE_ITEM
|
||||
@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
|
||||
|
||||
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = sgd_out_of_place(
|
||||
to<Tensor>(stack[0]),
|
||||
to<Tensor>(stack[1]),
|
||||
float(to<double>(stack[2])),
|
||||
to<double>(stack[3]),
|
||||
to<bool>(stack[4]));
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<Tensor>(stack[1]),
|
||||
float(torch::stable::detail::to<double>(stack[2])),
|
||||
torch::stable::detail::to<double>(stack[3]),
|
||||
torch::stable::detail::to<bool>(stack[4]));
|
||||
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
|
||||
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = identity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
|
||||
Tensor my_abs(Tensor t) {
|
||||
const auto num_args = 1;
|
||||
StableIValue stack[num_args];
|
||||
stack[0] = from(t);
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::abs", "", stack);
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
|
||||
|
||||
auto mf = aoti_torch_memory_format_contiguous_format();
|
||||
|
||||
stack[0] = from(t);
|
||||
stack[1] = from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = from(std::nullopt); // layout
|
||||
stack[3] = from(std::optional(device)); // device
|
||||
stack[4] = from(std::optional(false)); // pin_memory
|
||||
stack[5] = from(std::optional(mf)); // memory_format
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
|
||||
stack[2] = torch::stable::detail::from(std::nullopt); // layout
|
||||
stack[3] = torch::stable::detail::from(std::optional(device)); // device
|
||||
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
|
||||
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
|
||||
|
||||
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
|
||||
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = from(res);
|
||||
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = from(t1);
|
||||
stack_exp[0] = torch::stable::detail::from(t1);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = from(t2);
|
||||
stack_neg[0] = torch::stable::detail::from(t2);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
StableIValue stack_is_leaf[1];
|
||||
stack_is_leaf[0] = from(t3);
|
||||
stack_is_leaf[0] = torch::stable::detail::from(t3);
|
||||
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
|
||||
|
||||
return std::make_tuple(
|
||||
to<Tensor>(stack_exp[0]),
|
||||
to<Tensor>(stack_neg[0]),
|
||||
to<bool>(stack_is_leaf[0]));
|
||||
torch::stable::detail::to<Tensor>(stack_exp[0]),
|
||||
torch::stable::detail::to<Tensor>(stack_neg[0]),
|
||||
torch::stable::detail::to<bool>(stack_is_leaf[0]));
|
||||
}
|
||||
|
||||
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
|
||||
stack[0] = from(std::get<0>(tuple));
|
||||
stack[1] = from(std::get<1>(tuple));
|
||||
stack[2] = from(std::get<2>(tuple));
|
||||
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
|
||||
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
|
||||
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor neg_exp(Tensor t) {
|
||||
StableIValue stack[1];
|
||||
stack[0] = from(t);
|
||||
stack[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack);
|
||||
return to<Tensor>(stack[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
|
||||
Tensor divide_neg_exp(Tensor t) {
|
||||
StableIValue stack_neg[1];
|
||||
stack_neg[0] = from(t);
|
||||
stack_neg[0] = torch::stable::detail::from(t);
|
||||
|
||||
StableIValue stack_exp[1];
|
||||
stack_exp[0] = from(t);
|
||||
stack_exp[0] = torch::stable::detail::from(t);
|
||||
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
|
||||
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
|
||||
|
||||
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
|
||||
stack_div[0] = stack_neg[0];
|
||||
stack_div[1] = stack_exp[0];
|
||||
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
|
||||
return to<Tensor>(stack_div[0]);
|
||||
return torch::stable::detail::to<Tensor>(stack_div[0]);
|
||||
}
|
||||
|
||||
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
bool res = is_contiguous(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
|
||||
}
|
||||
|
||||
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
|
||||
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
|
||||
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_empty_like(Tensor t) {
|
||||
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_empty_like(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
bool my_is_cpu(Tensor t) {
|
||||
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
|
||||
|
||||
|
||||
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_is_cpu(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor fill_infinity(Tensor t) {
|
||||
@ -296,8 +296,8 @@ void boxed_fill_infinity(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = fill_infinity(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_pad(Tensor t) {
|
||||
@ -310,8 +310,8 @@ void boxed_my_pad(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_pad(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
|
||||
@ -323,11 +323,11 @@ void boxed_my_narrow(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
auto res = my_narrow(
|
||||
to<Tensor>(stack[0]),
|
||||
to<int64_t>(stack[1]),
|
||||
to<int64_t>(stack[2]),
|
||||
to<int64_t>(stack[3]));
|
||||
stack[0] = from(res);
|
||||
torch::stable::detail::to<Tensor>(stack[0]),
|
||||
torch::stable::detail::to<int64_t>(stack[1]),
|
||||
torch::stable::detail::to<int64_t>(stack[2]),
|
||||
torch::stable::detail::to<int64_t>(stack[3]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
|
||||
}
|
||||
|
||||
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
Tensor my_clone(Tensor t) {
|
||||
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
|
||||
stack[0] = from(tensor_res);
|
||||
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(tensor_res);
|
||||
}
|
||||
|
||||
|
||||
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_zero_(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax(Tensor t) {
|
||||
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
Tensor my_amax_vec(Tensor t) {
|
||||
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
|
||||
}
|
||||
|
||||
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = my_amax_vec(to<Tensor>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
bool res = test_default_constructor(to<bool>(stack[0]));
|
||||
stack[0] = from(res);
|
||||
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -478,6 +478,56 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my_amax_vec", &boxed_my_amax_vec);
|
||||
}
|
||||
|
||||
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
|
||||
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
|
||||
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
|
||||
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
|
||||
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
|
||||
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
|
||||
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
|
||||
}
|
||||
|
||||
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
|
||||
}
|
||||
|
||||
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
|
||||
// This function tests that my__foreach_mul can take in std::initializer_lists
|
||||
// in addition to std::vectors.
|
||||
Tensor t1_1 = my_clone(t1);
|
||||
Tensor t1_2 = my_clone(t1);
|
||||
Tensor t2_1 = my_clone(t2);
|
||||
Tensor t2_2 = my_clone(t2);
|
||||
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
|
||||
}
|
||||
|
||||
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
|
||||
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
|
||||
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
|
||||
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
|
||||
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
|
||||
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
|
||||
}
|
||||
|
||||
// Test functions for torch::stable::accelerator APIs
|
||||
|
||||
#ifdef LAE_USE_CUDA
|
||||
@ -500,8 +550,8 @@ void boxed_test_device_guard(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_device_guard_set_index() {
|
||||
@ -520,7 +570,7 @@ void boxed_test_device_guard_set_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_device_guard_set_index();
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_stream(int32_t device_index) {
|
||||
@ -536,8 +586,8 @@ void boxed_test_stream(
|
||||
StableIValue* stack,
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
|
||||
stack[0] = from(res);
|
||||
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
int64_t test_get_current_device_index() {
|
||||
@ -549,7 +599,7 @@ void boxed_test_get_current_device_index(
|
||||
uint64_t num_args,
|
||||
uint64_t num_outputs) {
|
||||
int64_t res = test_get_current_device_index();
|
||||
stack[0] = from(res);
|
||||
stack[0] = torch::stable::detail::from(res);
|
||||
}
|
||||
|
||||
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
|
||||
@ -565,4 +615,5 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
|
||||
m.impl("test_stream", &boxed_test_stream);
|
||||
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
|
||||
}
|
||||
|
||||
#endif // LAE_USE_CUDA
|
||||
|
||||
@ -333,3 +333,45 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
|
||||
Returns: New zeros tensor
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
|
||||
|
||||
|
||||
def my__foreach_mul_(tensors, others) -> ():
|
||||
"""
|
||||
Updates tensors to be the result of pointwise multiplying with others.
|
||||
|
||||
Args:
|
||||
tensors: list of tensors
|
||||
others: list of tensors (with the same corresponding shapes as tensors)
|
||||
|
||||
Returns: nothing, tensors is updated in place.
|
||||
"""
|
||||
torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
|
||||
|
||||
|
||||
def my__foreach_mul(tensors, others) -> list[Tensor]:
|
||||
"""
|
||||
Returns a list of tensors that are the results of pointwise multiplying
|
||||
tensors and others.
|
||||
|
||||
Args:
|
||||
tensors: list of tensors
|
||||
others: list of tensors (with the same corresponding shapes as tensors)
|
||||
|
||||
Returns: list of multiplied tensors
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
|
||||
|
||||
|
||||
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
|
||||
"""
|
||||
Returns a list of 2 tensors corresponding to the square of the inputs.
|
||||
|
||||
Args:
|
||||
t1: Tensor
|
||||
t2: Tensor
|
||||
|
||||
Returns: list of [t1^2, t2^2]
|
||||
"""
|
||||
return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
|
||||
t1, t2
|
||||
)
|
||||
|
||||
@ -367,6 +367,57 @@ if not IS_WINDOWS:
|
||||
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
|
||||
self.assertEqual(result.stride(), expected.stride())
|
||||
|
||||
def test_my__foreach_mul_(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
tensors_c = [t.clone() for t in tensors]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
|
||||
expected_values = torch._foreach_mul(tensors_c, others)
|
||||
|
||||
for tensor_t, expected_t in zip(tensors, expected_values):
|
||||
self.assertEqual(tensor_t, expected_t)
|
||||
|
||||
def test_my__foreach_mul(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
N = 5
|
||||
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
others = [torch.rand(32, 16, device=device) for _ in range(N)]
|
||||
|
||||
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
|
||||
for result_t, expected_t in zip(result, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
def _make_cuda_tensors(prior_mem):
|
||||
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
|
||||
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
|
||||
|
||||
expected = torch._foreach_mul(tensors, others)
|
||||
for result_t, expected_t in zip(cuda_res, expected):
|
||||
self.assertEqual(result_t, expected_t)
|
||||
|
||||
if tensors[0].is_cuda:
|
||||
init_mem = torch.cuda.memory_allocated(device)
|
||||
for _ in range(3):
|
||||
_make_cuda_tensors(init_mem)
|
||||
curr_mem = torch.cuda.memory_allocated(device)
|
||||
self.assertEqual(curr_mem, init_mem)
|
||||
|
||||
def test_make_tensor_clones_and_call_foreach(self, device):
|
||||
import libtorch_agnostic
|
||||
|
||||
t1 = torch.rand(2, 5, device=device)
|
||||
t2 = torch.rand(3, 4, device=device)
|
||||
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
|
||||
self.assertEqual(result[0], t1 * t1)
|
||||
self.assertEqual(result[1], t2 * t2)
|
||||
|
||||
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from backend import get_custom_backend_library_path, Model, to_custom_backend
|
||||
@ -41,14 +40,11 @@ class TestCustomBackend(TestCase):
|
||||
self.test_execute()
|
||||
|
||||
# Save and load.
|
||||
f = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
f.close()
|
||||
torch.jit.save(self.model, f.name)
|
||||
loaded = torch.jit.load(f.name)
|
||||
finally:
|
||||
os.unlink(f.name)
|
||||
self.model = loaded
|
||||
self.model = loaded
|
||||
|
||||
# Test execution again.
|
||||
self.test_execute()
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
@ -144,16 +143,13 @@ def forward(self, arg0_1):
|
||||
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
||||
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
||||
# close the file after creation and try to remove it manually.
|
||||
file = tempfile.NamedTemporaryFile(delete=False)
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile() as file:
|
||||
file.close()
|
||||
model.save(file.name)
|
||||
loaded = torch.jit.load(file.name)
|
||||
finally:
|
||||
os.unlink(file.name)
|
||||
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
output = loaded.forward(torch.ones(5))
|
||||
self.assertTrue(output.allclose(torch.ones(5) + 1))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: fsdp"]
|
||||
import functools
|
||||
import os
|
||||
import unittest.mock
|
||||
import unittest
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.test_case import run_tests
|
||||
@ -37,9 +37,9 @@ import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import fully_shard
|
||||
logger = logging.getLogger("torch.distributed._composable.fsdp")
|
||||
logger = logging.getLogger("torch.distributed.fsdp.fully_shard")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
device = {device_type.type}
|
||||
device = '{device_type.type}'
|
||||
torch.manual_seed(0)
|
||||
model = nn.Sequential(*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)])
|
||||
for layer in model:
|
||||
|
||||
@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
store=dist.FileStore(self.file_name, self.world_size),
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer(self):
|
||||
"""
|
||||
This tests that replicate works on a transformer model with fully_shard and replicate layers
|
||||
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
for parameter in layer.parameters():
|
||||
self.assertEqual(parameter.placements, (Shard(dim=0),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_transformer_managed_modules(self):
|
||||
"""
|
||||
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
|
||||
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
replicate_model = replicate(replicate_model)
|
||||
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_replicate_tp_device_mesh(self):
|
||||
"""
|
||||
This tests that a user can pass in a device mesh to replicate a module
|
||||
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(parameter.device_mesh.shape, (2,))
|
||||
self.assertEqual(parameter.placements, (Replicate(),))
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_replicate_fsdp(self):
|
||||
"""
|
||||
Tests that replicate_model has the same behavior as original model when training
|
||||
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
|
||||
self.assertEqual(replicate_loss, loss)
|
||||
check_sharded_parity(self, model, replicate_model)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_train_parity_2d_mlp(self):
|
||||
"""
|
||||
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training
|
||||
|
||||
@ -80,7 +80,7 @@ class TestSACILP(TestCase):
|
||||
# postprocessing due to the fact that for ModTracker, the post backward hook
|
||||
# is not being called for modules whose inputs don't require gradients
|
||||
# TODO: fix this in ModTracker and ensure it does not lead to any perf regression
|
||||
if _ModState.POST_BW not in mod_stats.snapshots.keys():
|
||||
if _ModState.POST_BW not in mod_stats.snapshots:
|
||||
mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
|
||||
copy.deepcopy(last_snapshot)
|
||||
)
|
||||
|
||||
@ -16,7 +16,7 @@ from torch.distributed.argparse_util import check_env, env
|
||||
class ArgParseUtilTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# remove any lingering environment variables
|
||||
for e in os.environ.keys():
|
||||
for e in os.environ.keys(): # noqa: SIM118
|
||||
if e.startswith("PET_"):
|
||||
del os.environ[e]
|
||||
|
||||
|
||||
@ -207,7 +207,7 @@ class TestDefaultStager(TestCase):
|
||||
for i, result in enumerate(staged_results):
|
||||
self.assertIsInstance(result, dict)
|
||||
# Verify the result contains the expected keys
|
||||
for key in state_dicts[i].keys():
|
||||
for key in state_dicts[i]:
|
||||
self.assertIn(key, result)
|
||||
|
||||
stager.close()
|
||||
|
||||
@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
|
||||
"""
|
||||
Saving a dtensor with uneven shards.
|
||||
@ -436,6 +436,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_checkpointable_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
@ -498,6 +499,7 @@ class TestCheckpointableReshard(DTensorTestBase):
|
||||
|
||||
@with_comms
|
||||
@with_temp_dir
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
|
||||
"""
|
||||
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
|
||||
|
||||
@ -60,7 +60,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
|
||||
)
|
||||
for key in state_dict_to_save.keys():
|
||||
for key in state_dict_to_save:
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
|
||||
)
|
||||
@ -89,7 +89,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
|
||||
)
|
||||
for key in state_dict_to_save.keys():
|
||||
for key in state_dict_to_save:
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
|
||||
)
|
||||
@ -116,7 +116,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
|
||||
)
|
||||
for key in state_dict_to_save.keys():
|
||||
for key in state_dict_to_save:
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
|
||||
)
|
||||
@ -156,7 +156,7 @@ class TestSingleRankSaveLoad(TestCase):
|
||||
self.assertEqual(
|
||||
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
|
||||
)
|
||||
for key in state_dict_to_save.keys():
|
||||
for key in state_dict_to_save:
|
||||
self.assertTrue(
|
||||
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
|
||||
)
|
||||
|
||||
@ -18,6 +18,7 @@ from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
|
||||
from torch.distributed.checkpoint.api import CheckpointException
|
||||
from torch.distributed.checkpoint.default_planner import (
|
||||
_create_default_local_metadata,
|
||||
_validate_global_plan,
|
||||
create_default_global_save_plan,
|
||||
create_default_local_load_plan,
|
||||
create_default_local_save_plan,
|
||||
@ -28,6 +29,7 @@ from torch.distributed.checkpoint.filesystem import CURRENT_DCP_VERSION
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
ChunkStorageMetadata,
|
||||
Metadata,
|
||||
MetadataIndex,
|
||||
TensorProperties,
|
||||
TensorStorageMetadata,
|
||||
@ -560,6 +562,32 @@ class TestPlannerHelpers(TestCase):
|
||||
self.assertTrue(_compare_save_plans(plan2, plan2))
|
||||
|
||||
|
||||
class TestValidateGlobalPlan(TestCase):
|
||||
def _make_metadata(self, chunks, size):
|
||||
storage = TensorStorageMetadata(
|
||||
properties=TensorProperties(dtype=torch.float32),
|
||||
size=torch.Size(size),
|
||||
chunks=chunks,
|
||||
)
|
||||
return Metadata(state_dict_metadata={"param": storage})
|
||||
|
||||
def test_non_overlapping_chunks(self):
|
||||
chunks = [
|
||||
ChunkStorageMetadata(offsets=torch.Size([i]), sizes=torch.Size([1]))
|
||||
for i in range(4)
|
||||
]
|
||||
metadata = self._make_metadata(chunks, [4])
|
||||
self.assertTrue(_validate_global_plan([SavePlan([])], metadata))
|
||||
|
||||
def test_detect_overlapping_chunks(self):
|
||||
chunks = [
|
||||
ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([2])),
|
||||
ChunkStorageMetadata(offsets=torch.Size([1]), sizes=torch.Size([2])),
|
||||
]
|
||||
metadata = self._make_metadata(chunks, [4])
|
||||
self.assertFalse(_validate_global_plan([SavePlan([])], metadata))
|
||||
|
||||
|
||||
class TestLoadPlanner(TestCase):
|
||||
@with_temp_dir
|
||||
def test_strict(self):
|
||||
|
||||
@ -769,7 +769,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
model_state_dict3 = copy.deepcopy(model_state_dict3)
|
||||
self.assertEqual(len(model_state_dict2), 2)
|
||||
self.assertEqual(len(model_state_dict3), 2)
|
||||
for key in model_state_dict3.keys():
|
||||
for key in model_state_dict3:
|
||||
full_fqn = f"l.{key}"
|
||||
value1 = model_state_dict1[full_fqn]
|
||||
value2 = model_state_dict2[full_fqn]
|
||||
@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
|
||||
self.assertEqual(cpu_model_value, meta_model_value)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
|
||||
# This test verifies that we can set model state dict by a meta device model
|
||||
# With the correlated changes in state_dict, meta device model should be accepted
|
||||
|
||||
@ -479,6 +479,7 @@ class TestFSDPMiscMultiProcess(FSDPTest):
|
||||
for (n, p), (n_prev, p_prev) in zip(
|
||||
fsdp_overlap.named_parameters(), fsdp_overlap_prev_params
|
||||
):
|
||||
self.assertEqual(n, n_prev)
|
||||
self.assertNotEqual(
|
||||
p,
|
||||
p_prev,
|
||||
|
||||
@ -587,9 +587,7 @@ class TestFSDPStateDict(FSDPTest):
|
||||
model, cpu_offload.offload_params, fp16
|
||||
)
|
||||
|
||||
ignore_keys = [
|
||||
k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
|
||||
]
|
||||
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
|
||||
|
||||
self._validate_state_dict_contents(
|
||||
model,
|
||||
@ -910,7 +908,7 @@ class TestFSDPStateDict(FSDPTest):
|
||||
with sd_mgr:
|
||||
fsdp_state_dict = model.state_dict()
|
||||
|
||||
ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
|
||||
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
|
||||
self._validate_state_dict_contents(
|
||||
model,
|
||||
fsdp_state_dict,
|
||||
@ -959,9 +957,7 @@ class TestFSDPStateDict(FSDPTest):
|
||||
# Full name of linear_skip param tensors in SkipModel, as would be
|
||||
# stored in checkpoint.
|
||||
linear_skip_tensor_names = [
|
||||
k
|
||||
for k in dict(module.named_parameters()).keys()
|
||||
if LINEAR_SKIP in k
|
||||
k for k in dict(module.named_parameters()) if LINEAR_SKIP in k
|
||||
]
|
||||
# skip SkipModule
|
||||
linear_skip = getattr(module, LINEAR_SKIP)
|
||||
|
||||
@ -137,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
# remove any lingering environment variables.
|
||||
for env in os.environ.keys():
|
||||
for env in os.environ.keys(): # noqa: SIM118
|
||||
if env.startswith("PET_"):
|
||||
del os.environ[env]
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class ElasticLaunchTest(TestCase):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
# remove any lingering environment variables
|
||||
for env in os.environ.keys():
|
||||
for env in os.environ.keys(): # noqa: SIM118
|
||||
if env.startswith("PET_"):
|
||||
del os.environ[env]
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ from torch.nn.modules.loss import MSELoss
|
||||
from torch.testing._internal.common_distributed import (
|
||||
MultiProcContinuousTest,
|
||||
requires_accelerator_dist_backend,
|
||||
skip_if_lt_x_gpu,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
check_leaked_tensors,
|
||||
@ -46,6 +47,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_MULTIACCELERATOR,
|
||||
)
|
||||
|
||||
|
||||
@ -56,7 +58,6 @@ batch_size = 64
|
||||
torch.manual_seed(0)
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
backend = dist.get_default_backend_for_device(device_type)
|
||||
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -231,6 +232,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_forward_only(self, ScheduleClass):
|
||||
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
|
||||
x_clone = x.clone()
|
||||
@ -274,6 +276,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_eval_inference_mode(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -351,6 +354,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_return_output(self, ScheduleClass):
|
||||
num_microbatches = 4
|
||||
if ScheduleClass in [
|
||||
@ -406,6 +410,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_multi_iter(self, ScheduleClass):
|
||||
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
chunks = 4
|
||||
@ -429,6 +434,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_kwargs_with_tracer(self, ScheduleClass):
|
||||
mod = ModelWithKwargs(d_hid, splits=self.world_size)
|
||||
mod.to(self.device)
|
||||
@ -481,6 +487,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_tracer(self, ScheduleClass):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -523,6 +530,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
|
||||
@parametrize("shape_inference", [True, False])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual(self, ScheduleClass, shape_inference):
|
||||
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
|
||||
|
||||
@ -586,6 +594,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
ScheduleInterleavedZeroBubble,
|
||||
],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_grad_with_manual_interleaved(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -650,6 +659,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -736,6 +746,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_v_shape_schedules(self, schedule_class):
|
||||
n_stages = 8
|
||||
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
|
||||
@ -780,6 +791,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_custom_function_callback(self):
|
||||
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
|
||||
n_stages = 8
|
||||
@ -979,6 +991,7 @@ class ScheduleTest(MultiProcContinuousTest):
|
||||
"ScheduleClass",
|
||||
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
|
||||
stages_per_rank = 2
|
||||
n_stages = stages_per_rank * self.world_size
|
||||
@ -1072,6 +1085,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
"schedule_class",
|
||||
[ScheduleVShaped, ScheduleUnbalanced],
|
||||
)
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_non_symmetric_stage_ids(self, schedule_class):
|
||||
n_stages = schedule_class.n_stages
|
||||
rank_stages = schedule_class.rank_stages
|
||||
@ -1121,6 +1135,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
|
||||
n_stages = 2
|
||||
stages_per_rank = 1
|
||||
@ -1181,6 +1196,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
|
||||
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
|
||||
)
|
||||
@parametrize("ScheduleClass", [ScheduleWithW])
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
|
||||
n_stages = ScheduleClass.n_stages
|
||||
num_microbatches = ScheduleClass.num_microbatches
|
||||
|
||||
@ -24,6 +24,7 @@ from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
skip_but_pass_in_sandcastle_if,
|
||||
TEST_MULTIACCELERATOR,
|
||||
)
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
@ -34,7 +35,6 @@ chunks = 8
|
||||
|
||||
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
|
||||
backend = dist.get_default_backend_for_device(device_type)
|
||||
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@ -464,6 +464,25 @@ def forward(self, b_parametrizations_buffer_original0, x):
|
||||
run(g, 64, 8)
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_dtensor_requires_grad_recompile(self):
|
||||
cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
@torch.compile(backend=cnt, fullgraph=True)
|
||||
def f(x):
|
||||
y = x * x
|
||||
return y.to_local()
|
||||
|
||||
full_x = torch.randn(8, 8, requires_grad=False)
|
||||
x = distribute_tensor(full_x, mesh, [Shard(0)])
|
||||
f(x)
|
||||
|
||||
full_x = torch.randn(8, 8, requires_grad=True)
|
||||
x = distribute_tensor(full_x, mesh, [Shard(0)])
|
||||
f(x)
|
||||
|
||||
self.assertEqual(cnt.frame_count, 2)
|
||||
|
||||
def test_dtensor_attribute_access_on_intermediate(self):
|
||||
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
|
||||
|
||||
|
||||
@ -535,6 +535,19 @@ class DTensorExportTest(TestCase):
|
||||
|
||||
self.assertEqual(fn(z), gm(z)[0])
|
||||
|
||||
def test_dtensor_data_dependent_index(self):
|
||||
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return x[y]
|
||||
|
||||
x = torch.randn(10)
|
||||
y = torch.randint(1, (10,)).bool()
|
||||
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
|
||||
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
|
||||
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(DTensorExportTest)
|
||||
|
||||
|
||||
@ -26,6 +26,7 @@ from torch.distributed.tensor.parallel import (
|
||||
RowwiseParallel,
|
||||
SequenceParallel,
|
||||
)
|
||||
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
create_local_tensor_test_class,
|
||||
@ -764,6 +765,7 @@ class DistMathOpsTest(DTensorTestBase):
|
||||
self.assertEqual(grad1_norm.device_mesh, mesh_y)
|
||||
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
def test_foreach_add_different_mesh(self):
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(
|
||||
|
||||
@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
|
||||
self.assertEqual(
|
||||
comm_mode.get_comm_counts(),
|
||||
{
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: 4,
|
||||
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
|
||||
},
|
||||
)
|
||||
expected_cost = [
|
||||
|
||||
@ -54,6 +54,7 @@ def apply_reordering_and_get_graph(graph, out_li) -> None:
|
||||
"max_compute_pre_fetch",
|
||||
"custom_runtime_estimation",
|
||||
"insert_overlap_deps",
|
||||
"collective_estimator",
|
||||
)
|
||||
for key in config_keys:
|
||||
if (val := getattr(dist_opts, key)) is not None:
|
||||
@ -943,6 +944,50 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
|
||||
correct = func(inputs_a, inputs_b, ranks=ranks)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
def test_collective_benchmarking_with_real_pg(self):
|
||||
"""Test collective benchmarking with real process group (falls back on fake)."""
|
||||
|
||||
def func(a):
|
||||
# Test all three collective types with 8x8 (power of 2 size = 256 elements = 1024 bytes for fp32)
|
||||
ar = _functional_collectives.all_reduce(a, "sum", "0")
|
||||
ag = _functional_collectives.all_gather_tensor(
|
||||
a, 0, list(range(self.world_size))
|
||||
)
|
||||
rs = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
|
||||
|
||||
b = torch.matmul(a, a)
|
||||
c = torch.matmul(ar, b)
|
||||
return c.sum() + ag.sum() + rs.sum()
|
||||
|
||||
patches = {
|
||||
**get_patches(),
|
||||
"aten_distributed_optimizations.collective_estimator": "benchmark",
|
||||
"aten_distributed_optimizations.custom_runtime_estimation": None, # Remove custom estimation so benchmarking happens
|
||||
}
|
||||
|
||||
with _dynamo_dist_per_rank_init(
|
||||
self.rank,
|
||||
self.world_size,
|
||||
self.backend(device_type),
|
||||
fake_pg=not at_least_x_gpu(2),
|
||||
):
|
||||
inputs = torch.ones(8, 8, dtype=torch.float, device=device_type) + self.rank
|
||||
|
||||
with torch._inductor.config.patch(patches):
|
||||
compiled = torch.compile(func)
|
||||
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
|
||||
|
||||
# Verify all three collective types are present
|
||||
FileCheck().check("all_reduce").check("all_gather").check(
|
||||
"reduce_scatter"
|
||||
).run(aten_graph_str)
|
||||
|
||||
# Test passes if compilation succeeded with benchmarking enabled
|
||||
# Cache verification is tricky due to multiprocess test setup
|
||||
correct = func(inputs)
|
||||
self.assertTrue(same(out, correct))
|
||||
|
||||
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
||||
@torch._inductor.config.patch(get_bucket_patches())
|
||||
def test_multidtype_bucketing(self):
|
||||
|
||||
@ -485,7 +485,7 @@ elif TEST_XPU:
|
||||
def exit_if_lt_x_accelerators(x):
|
||||
if torch.accelerator.is_available():
|
||||
if torch.accelerator.device_count() < x:
|
||||
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
|
||||
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
|
||||
|
||||
|
||||
def with_comms(func=None):
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
# flake8: noqa: B950
|
||||
# flake8: noqa: E731
|
||||
import contextlib
|
||||
import copy
|
||||
import functools
|
||||
@ -15,7 +17,11 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch._dynamo.backends.common import aot_autograd
|
||||
from torch._dynamo.testing import CompileCounterWithBackend
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
CompileCounterWithBackend,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch._higher_order_ops.wrap import tag_activation_checkpoint
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
|
||||
@ -1649,6 +1655,43 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
|
||||
self.assertEqual(opt_fn(x), fn(x))
|
||||
|
||||
def test_return_same_element_twice(self):
|
||||
def gn(x):
|
||||
y = torch.sin(x)
|
||||
return y, y
|
||||
|
||||
def fn(x):
|
||||
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[4, 4]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
|
||||
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
|
||||
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
|
||||
return (getitem, getitem_1)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[4, 4]"):
|
||||
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (y, y)
|
||||
""",
|
||||
)
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_mutation(self):
|
||||
counter = 0
|
||||
@ -1672,6 +1715,114 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
|
||||
# The mutation is not reapplied in the backward because the flag was on.
|
||||
self.assertEqual(counter, 1)
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_list_mutation(self):
|
||||
def gn(x, z):
|
||||
out = x.sin()
|
||||
z.append(out)
|
||||
return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
|
||||
|
||||
def fn(x):
|
||||
z = []
|
||||
|
||||
out1, out2 = torch.utils.checkpoint.checkpoint(
|
||||
gn,
|
||||
x,
|
||||
z,
|
||||
use_reentrant=False,
|
||||
)
|
||||
|
||||
return out1, z[0]
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
|
||||
def test_nonlocal_list_mutation_hidden(self):
|
||||
def gn(x, z):
|
||||
o = torch.matmul(x, x) @ x
|
||||
out = x.sin()
|
||||
z.append(out)
|
||||
return torch.cos(torch.sin(o)), torch.sin(x)
|
||||
|
||||
def fn(x):
|
||||
z = []
|
||||
|
||||
outs = torch.utils.checkpoint.checkpoint(
|
||||
gn,
|
||||
x,
|
||||
z,
|
||||
use_reentrant=False,
|
||||
)
|
||||
out1 = outs[0]
|
||||
# Check that the extra output pytree handling is done properly
|
||||
out2 = outs[-1]
|
||||
|
||||
return out1 + out2, z[0]
|
||||
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
ref = fn(x)
|
||||
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(ref[1], res[1])
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[4, 4]"):
|
||||
l_x_ = L_x_
|
||||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
|
||||
out1: "f32[4, 4]" = tag_activation_checkpoint[0]
|
||||
out2: "f32[4, 4]" = tag_activation_checkpoint[1]
|
||||
getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None
|
||||
|
||||
add: "f32[4, 4]" = out1 + out2; out1 = out2 = None
|
||||
return (add, getitem_4)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[4, 4]"):
|
||||
matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_)
|
||||
o: "f32[4, 4]" = matmul @ l_x_
|
||||
|
||||
out: "f32[4, 4]" = l_x_.sin()
|
||||
|
||||
sin_1: "f32[4, 4]" = torch.sin(o)
|
||||
child: "f32[4, 4]" = torch.cos(sin_1)
|
||||
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
|
||||
return (child, child_1, matmul, o, out, sin_1)
|
||||
""",
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[4, 4]"):
|
||||
mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1)
|
||||
mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None
|
||||
|
||||
sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
|
||||
|
||||
sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None
|
||||
cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None
|
||||
sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
|
||||
|
||||
add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None
|
||||
return (add, sin, primals_1)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
devices = ["cuda", "hpu"]
|
||||
instantiate_device_type_tests(
|
||||
|
||||
@ -408,6 +408,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
self.assertEqual(ref0, res0)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
||||
@unittest.skip(
|
||||
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
|
||||
)
|
||||
def test_cuda_event_reconstruct(self):
|
||||
def fn(x):
|
||||
e = torch.cuda.Event()
|
||||
@ -425,6 +428,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
||||
@unittest.skip(
|
||||
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
|
||||
)
|
||||
def test_cuda_event_across_graph_break(self):
|
||||
def fn(x):
|
||||
e = torch.cuda.Event()
|
||||
@ -446,9 +452,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
res = opt_fn(x)
|
||||
self.assertEqual(ref[0], res[0])
|
||||
self.assertEqual(cnts.frame_count, 2)
|
||||
self.assertEqual(cnts.op_count, 9)
|
||||
self.assertEqual(cnts.op_count, 10)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
||||
@unittest.skip(
|
||||
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
|
||||
)
|
||||
def test_cuda_event_created_outside_of_graph(self):
|
||||
user_stream = torch.cuda.Stream()
|
||||
event = torch.cuda.Event()
|
||||
@ -478,9 +487,12 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
|
||||
res = run_iters(func, compile=True)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertEqual(cnts.frame_count, 1)
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
self.assertEqual(cnts.op_count, 4)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
||||
@unittest.skip(
|
||||
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
|
||||
)
|
||||
def test_cuda_event_method_create_stream_outside_of_compile(self):
|
||||
def fn(x, cur_stream, new_stream):
|
||||
x = torch.mul(x, 1)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user