Compare commits


1 Commit

435 changed files with 6033 additions and 7777 deletions

View File

@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
;;
rocm*)
BASE_TARGET=rocm
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
;;
*)
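
Editorial note on the version gate added here, not part of the diff itself: the substring test `[[ "$ROCM_VERSION" == *"7.0"* ]]` only fires when the version string contains "7.0", so ROCm 7.1 and later would not pick up the extra architectures even though the comment says "starting in ROCm 7.0". A hedged alternative sketch, assuming ROCM_VERSION and PYTORCH_ROCM_ARCH are already set as in the surrounding script:

```bash
# Hypothetical variant, not from the PR: gate on "major version >= 7" so that
# future ROCm 7.x releases also pick up the extra architectures.
ROCM_MAJOR=$(echo "${ROCM_VERSION}" | cut -d. -f1)
if [[ "${ROCM_MAJOR}" -ge 7 ]]; then
  PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
```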

View File

@ -168,18 +168,6 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.11-clang12)
ANACONDA_PYTHON_VERSION=3.11
CLANG_VERSION=12
VISION=no
TRITON=no
;;
pytorch-linux-jammy-py3.12-clang12)
ANACONDA_PYTHON_VERSION=3.12
CLANG_VERSION=12
VISION=no
TRITON=no
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10
@ -207,9 +195,9 @@ case "$tag" in
NINJA_VERSION=1.9.0
TRITON=yes
;;
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=13
GCC_VERSION=11
VISION=yes
XPU_VERSION=2025.2
NINJA_VERSION=1.9.0
@ -260,12 +248,6 @@ case "$tag" in
HALIDE=yes
TRITON=yes
;;
pytorch-linux-jammy-cuda13.0-py3.12-pallas)
CUDA_VERSION=13.0.0
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=11
PALLAS=yes
;;
pytorch-linux-jammy-py3.12-triton-cpu)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
@ -387,7 +369,6 @@ docker build \
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "HALIDE=${HALIDE}" \
--build-arg "PALLAS=${PALLAS}" \
--build-arg "XPU_VERSION=${XPU_VERSION}" \
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
--build-arg "ACL=${ACL:-}" \

View File

@ -1 +0,0 @@
0.8.0

View File

@ -1,40 +0,0 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
# Get the pinned JAX version (same for all CUDA versions)
JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
function install_jax_12() {
echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
}
function install_jax_13() {
echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Verify installation
python -c "import jax" # check for errors
echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
}
# idiomatic parameter and option handling in sh
while test $# -gt 0
do
case "$1" in
12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
;;
13.0|13.0.*) install_jax_13;
;;
*) echo "bad argument $1"; exit 1
;;
esac
shift
done
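
For orientation only, inferred from the case patterns above (the helper is deleted by this change): the script took the CUDA version as a positional argument and picked the matching JAX extra. A hypothetical invocation sketch:

```bash
# Illustrative usage of the removed helper; the argument is the CUDA version
# the Docker image is being built for.
bash install_jax.sh 12.8   # installs "jax[cuda12]==${JAX_VERSION}"
bash install_jax.sh 13.0   # installs "jax[cuda13]==${JAX_VERSION}"
```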

View File

@ -9,7 +9,7 @@ set -xe
function install_ubuntu() {
. /etc/os-release
if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
echo "Ubuntu version ${VERSION_CODENAME} not supported"
exit
fi
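
A brief aside on the idiom in this check, not part of the change itself: the space-padded `=~` match is a common shell test for "is this word in a whitespace-separated list", and the hunk narrows that list from "jammy noble" to "jammy". A minimal sketch with made-up values:

```bash
# The surrounding spaces force whole-word matches (" jam " would not match).
supported_codenames=" jammy noble "
VERSION_CODENAME="noble"
if [[ "${supported_codenames}" =~ " ${VERSION_CODENAME} " ]]; then
  echo "Ubuntu ${VERSION_CODENAME} is supported"
fi
```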
@ -35,24 +35,25 @@ function install_ubuntu() {
# The xpu-smi packages
apt-get install -y flex bison xpu-smi
# Compute and Media Runtimes
if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Compute and Media Runtimes
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
intel-opencl-icd intel-level-zero-gpu level-zero \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
else # jammy
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
else # rolling driver
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
fi
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
# Install Intel Support Packages
apt-get install -y ${XPU_PACKAGES}
@ -65,7 +66,7 @@ function install_ubuntu() {
function install_rhel() {
. /etc/os-release
if [[ "${ID}" == "rhel" ]]; then
if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
echo "RHEL version ${VERSION_ID} not supported"
exit
fi
@ -146,7 +147,7 @@ function install_sles() {
XPU_DRIVER_VERSION=""
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2523"
XPU_DRIVER_VERSION="/lts/2350"
fi
# Default use Intel® oneAPI Deep Learning Essentials 2025.1

View File

@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
fi
BASE_TARGET=rocm
GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
;;
*)

View File

@ -87,7 +87,11 @@ case ${image} in
MANY_LINUX_VERSION="2_28"
DEVTOOLSET_VERSION="11"
GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
# add gfx950, gfx115x conditionally starting in ROCm 7.0
if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
fi
DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
;;
manylinux2_28-builder:xpu)

View File

@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
RUN rm install_halide.sh common_utils.sh halide.txt
ARG PALLAS
ARG CUDA_VERSION
# Install JAX with CUDA support (for Pallas)
COPY ./common/install_jax.sh install_jax.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
ARG ONNX
# Install ONNX dependencies
COPY ./common/install_onnx.sh ./common/common_utils.sh ./

View File

@ -337,7 +337,7 @@ test_python() {
test_python_smoke() {
# Smoke tests for H100/B200
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running
assert_git_not_dirty
}
@ -824,11 +824,6 @@ test_inductor_halide() {
assert_git_not_dirty
}
test_inductor_pallas() {
python test/run_test.py --include inductor/test_pallas.py --verbose
assert_git_not_dirty
}
test_inductor_triton_cpu() {
python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
assert_git_not_dirty
@ -1729,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
test_inductor_halide
elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
test_inductor_pallas
elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
test_inductor_triton_cpu
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then

View File

@ -1 +1 @@
ca2212438fdd8ce29b66999ed70ed54b0f9372d1
cfbc5c2f1c798991715a6b06bb3ce46478c4487c

.github/labeler.yml vendored
View File

@ -138,8 +138,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -149,8 +148,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/**/*cublas*
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
@ -160,8 +158,7 @@
- test/test_matmul_cuda.py
- test/test_scaled_matmul_cuda.py
- test/inductor/test_fp8.py
- aten/src/ATen/native/cuda/*Blas.cpp
- aten/src/ATen/cuda/CUDA*Blas.*
- aten/src/ATen/native/cuda/Blas.cpp
- torch/_inductor/kernel/mm.py
- test/inductor/test_max_autotune.py
- third_party/fbgemm

View File

@ -10,4 +10,3 @@
pathFilter:
- 'torch/csrc/inductor/aoti_torch/c/*'
- 'torch/csrc/inductor/aoti_torch/generated/*'
- 'torch/csrc/stable/c/*'

View File

@ -2,8 +2,8 @@ tracking_issue: 24422
ciflow_tracking_issue: 64124
ciflow_push_tags:
- ciflow/b200
- ciflow/b200-distributed
- ciflow/b200-symm-mem
- ciflow/b200-distributed
- ciflow/binaries
- ciflow/binaries_libtorch
- ciflow/binaries_wheel
@ -22,8 +22,6 @@ ciflow_push_tags:
- ciflow/inductor-perf-test-nightly-xpu
- ciflow/inductor-periodic
- ciflow/inductor-rocm
- ciflow/inductor-rocm-mi200
- ciflow/inductor-rocm-mi300
- ciflow/linux-aarch64
- ciflow/mps
- ciflow/nightly
@ -35,13 +33,11 @@ ciflow_push_tags:
- ciflow/quantization-periodic
- ciflow/riscv64
- ciflow/rocm
- ciflow/rocm-mi200
- ciflow/rocm-mi300
- ciflow/rocm-mi355
- ciflow/rocm-navi31
- ciflow/s390
- ciflow/slow
- ciflow/slow-rocm-mi200
- ciflow/torchbench
- ciflow/triton_binaries
- ciflow/trunk

View File

@ -56,8 +56,6 @@ jobs:
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.10-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-py3.14-clang12,
pytorch-linux-jammy-rocm-n-py3,
@ -67,10 +65,9 @@ jobs:
pytorch-linux-jammy-py3.10-gcc11,
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
pytorch-linux-jammy-py3.12-halide,
pytorch-linux-jammy-cuda13.0-py3.12-pallas,
pytorch-linux-jammy-xpu-n-1-py3,
pytorch-linux-noble-xpu-n-py3,
pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-xpu-n-py3,
pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
pytorch-linux-jammy-py3-clang18-asan,
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,

View File

@ -83,8 +83,8 @@ jobs:
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -117,7 +117,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-noble-xpu-n-py3.10
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
@ -137,7 +137,7 @@ jobs:
uses: ./.github/workflows/_xpu-test.yml
needs: xpu-n-py3_10-inductor-benchmark-build
with:
build-environment: linux-noble-xpu-n-py3.10
build-environment: linux-jammy-xpu-n-py3.10
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}

View File

@ -7,7 +7,6 @@ on:
- release/*
tags:
- ciflow/inductor-rocm/*
- ciflow/inductor-rocm-mi300/*
workflow_dispatch:
concurrency:

View File

@ -7,7 +7,7 @@ on:
branches:
- release/*
tags:
- ciflow/inductor-rocm-mi200/*
- ciflow/inductor-rocm/*
workflow_dispatch:
concurrency:

View File

@ -81,32 +81,6 @@ jobs:
test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
secrets: inherit
inductor-pallas-build:
name: inductor-pallas-build
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-py3.12-pallas
cuda-arch-list: '8.9'
runner: linux.8xlarge.memory
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
test-matrix: |
{ include: [
{ config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" },
]}
secrets: inherit
inductor-pallas-test:
name: inductor-pallas-test
uses: ./.github/workflows/_linux-test.yml
needs: inductor-pallas-build
with:
build-environment: linux-jammy-py3.12-gcc11
docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
secrets: inherit
inductor-triton-cpu-build:
name: inductor-triton-cpu-build
uses: ./.github/workflows/_linux-build.yml

View File

@ -11,6 +11,7 @@ on:
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi200/*
branches:
- release/*

View File

@ -11,7 +11,6 @@ on:
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/periodic-rocm-mi300/*
branches:
- release/*

View File

@ -342,16 +342,16 @@ jobs:
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
# This should sync with the build in xpu.yml but xpu uses a larger runner
# sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },

View File

@ -6,7 +6,6 @@ on:
- main
- release/*
tags:
- ciflow/rocm/*
- ciflow/rocm-mi300/*
workflow_dispatch:
schedule:

View File

@ -5,7 +5,7 @@ on:
branches:
- release/*
tags:
- ciflow/rocm-mi200/*
- ciflow/rocm/*
workflow_dispatch:
schedule:
- cron: 29 8 * * * # about 1:29am PDT

View File

@ -1,81 +0,0 @@
# This workflow is dedicated to host slow jobs that are run only periodically because
# they are too slow to run in every commit. The list of slow tests can be found in
# https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
name: slow-rocm-mi200
on:
push:
branches:
- release/*
tags:
- ciflow/slow/*
- ciflow/slow-rocm-mi200/*
schedule:
- cron: 0 */3 * * *
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -105,6 +105,36 @@ jobs:
test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-rocm-py3_10-build:
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-rocm-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
]}
secrets: inherit
linux-jammy-rocm-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-jammy-rocm-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-jammy-rocm-py3_10-build
- target-determination
with:
build-environment: linux-jammy-rocm-py3.10
docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3_10-clang18-asan-build:
name: linux-jammy-py3.10-clang18-asan
uses: ./.github/workflows/_linux-build.yml

View File

@ -11,16 +11,15 @@ on:
- inductor
- unstable
- slow
- slow-rocm-mi200
- unstable-periodic
- inductor-periodic
- rocm-mi200
- rocm
- rocm-mi300
- rocm-mi355
- inductor-micro-benchmark
- inductor-micro-benchmark-x86
- inductor-cu124
- inductor-rocm-mi200
- inductor-rocm
- inductor-rocm-mi300
- mac-mps
- linux-aarch64

View File

@ -47,15 +47,15 @@ jobs:
]}
secrets: inherit
linux-noble-xpu-n-py3_10-build:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-build:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
sync-tag: linux-xpu-n-build
runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
build-environment: linux-noble-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
build-environment: linux-jammy-xpu-n-py3.10
docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
runner: linux.c7i.12xlarge
test-matrix: |
{ include: [
@ -74,17 +74,17 @@ jobs:
]}
secrets: inherit
linux-noble-xpu-n-py3_10-test:
name: linux-noble-xpu-n-py3.10
linux-jammy-xpu-n-py3_10-test:
name: linux-jammy-xpu-n-py3.10
uses: ./.github/workflows/_xpu-test.yml
needs: linux-noble-xpu-n-py3_10-build
needs: linux-jammy-xpu-n-py3_10-build
permissions:
id-token: write
contents: read
with:
build-environment: linux-noble-xpu-n-py3.10
docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
build-environment: linux-jammy-xpu-n-py3.10
docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
secrets: inherit
windows-xpu-n-1-build:

.gitignore vendored
View File

@ -127,6 +127,7 @@ torch/test/
torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
torch/_inductor/kernel/vendored_templates/*
minifier_launcher.py
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*

View File

@ -143,8 +143,7 @@ init_command = [
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
'numpy==2.3.4 ; python_version >= "3.14"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'pyrefly==0.36.2',
'sympy==1.13.3',
@ -1402,7 +1401,7 @@ init_command = [
'--dry-run={{DRYRUN}}',
'usort==1.0.8.post1',
'isort==6.0.1',
'ruff==0.14.4', # sync with RUFF
'ruff==0.13.1', # sync with RUFF
]
is_formatter = true
@ -1537,7 +1536,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'ruff==0.14.4', # sync with PYFMT
'ruff==0.13.1', # sync with PYFMT
]
is_formatter = true

View File

@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg
# Low Precision & Grouped GEMMs
# Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
/aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58

View File

@ -174,12 +174,6 @@ class TORCH_API Context {
static long versionCuDNN() {
return detail::getCUDAHooks().versionCuDNN();
}
static long versionRuntimeCuDNN() {
return detail::getCUDAHooks().versionRuntimeCuDNN();
}
static long versionCuDNNFrontend() {
return detail::getCUDAHooks().versionCuDNNFrontend();
}
static bool hasCuSOLVER() {
return detail::getCUDAHooks().hasCuSOLVER();
}
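
Context for the removed accessors (editorial, hedged): `versionRuntimeCuDNN()` reported the cuDNN library actually loaded at runtime, which can differ from the compile-time headers when cuDNN is dynamically linked; call sites elsewhere in this diff switch to `versionCuDNN()` or query `cudnnGetVersion()` directly. A standalone sketch of the distinction, assuming a cuDNN development environment (the `main` program is illustrative, not PyTorch code):

```cpp
#include <cudnn.h>
#include <cstdio>

int main() {
  std::printf("compiled against cuDNN %d\n", CUDNN_VERSION);   // header macro, fixed at build time
  std::printf("running with cuDNN %zu\n", cudnnGetVersion());  // version of the loaded library
  return 0;
}
```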

View File

@ -6,7 +6,6 @@
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
#include <torch/headeronly/core/Dispatch.h>
#ifdef __CUDACC__
#include <cuda.h> // For CUDA_VERSION
@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
} \
} while (0)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
case enum_type: { \
AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
return __VA_ARGS__(); \
}
#define AT_DISPATCH_CASE(enum_type, ...) \
AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
return __VA_ARGS__(); \
}
namespace detail {
inline at::ScalarType scalar_type(at::ScalarType s) {
return s;
}
} // namespace detail
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// but we're just being safe (and it doesn't hurt.) Note we must
// use it to shut up warnings about unused store.
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
THO_DISPATCH_SWITCH_TMPL( \
RECORD_KERNEL_FUNCTION_DTYPE, \
TORCH_CHECK_NOT_IMPLEMENTED, \
TYPE, \
NAME, \
__VA_ARGS__)
#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
constexpr const char* at_dispatch_name = NAME; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK_NOT_IMPLEMENTED( \
false, \
'"', \
at_dispatch_name, \
"\" not implemented for '", \
toString(_st), \
"'"); \
} \
C10_DIAGNOSTIC_POP() \
}()
#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
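
For readers less familiar with the dispatch machinery shown above, a minimal hedged usage sketch (the standard ATen pattern, not code from this PR; `report_elem_size` is a made-up helper): wrappers such as `AT_DISPATCH_FLOATING_TYPES` expand to `AT_DISPATCH_SWITCH` plus a bundle of `AT_DISPATCH_CASE`s, and inside the lambda `scalar_t` is bound to the C++ type matching the runtime dtype.

```cpp
#include <ATen/Dispatch.h>
#include <cstdio>

// Hypothetical helper: report how wide the dispatched scalar type is.
void report_elem_size(at::ScalarType st) {
  AT_DISPATCH_FLOATING_TYPES(st, "report_elem_size", [&] {
    std::printf("dispatched with a %zu-byte scalar_t\n", sizeof(scalar_t));
  });
}
```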

View File

@ -1,8 +1,3 @@
#pragma once
#include <torch/headeronly/core/Dispatch_v2.h>
// Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
#include <ATen/Dispatch.h>
// This is a new implementation of the AT_DISPATCH macro family from
@ -79,19 +74,41 @@
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right.
// Public API macros
// See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
AT_DISPATCH_SWITCH, \
AT_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
// This macro lets you pass an arbitrary expression that may contain internal
// commas to another macro without having the commas causing the expression
// to be interpreted as being multiple arguments
#define AT_WRAP(...) __VA_ARGS__
#define AT_FLOAT8_TYPES \
c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
#define AT_INTEGRAL_TYPES \
c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
#define AT_INTEGRAL_TYPES_V2 \
AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
// NB: not *actually* all types
#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
#define AT_ALL_TYPES_AND_COMPLEX \
AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
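
A hedged usage sketch of the v2 entry point these type-set bundles feed (pattern taken from the header's own documentation; `report_elem_size_v2` is a made-up name): the body is wrapped with `AT_WRAP` and the accepted dtypes follow as trailing arguments, mixing `AT_EXPAND`ed bundles with individual scalar types.

```cpp
#include <ATen/Dispatch_v2.h>
#include <cstdio>

// Hypothetical helper using the v2 dispatch entry point.
void report_elem_size_v2(at::ScalarType st) {
  AT_DISPATCH_V2(st, "report_elem_size_v2", AT_WRAP([&] {
    std::printf("dispatched with a %zu-byte scalar_t\n", sizeof(scalar_t));
  }), AT_EXPAND(AT_ALL_TYPES), at::kHalf, at::kBFloat16);
}
```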
// Helper macros
// Unused helper macros, kept for BC:
#define AT_AP_VAR(N, T, ...) \
AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
#define AT_CONCAT_AUX(a, b) a##b
#define AT_EXPAND(X) X
// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);
num_args = 60
nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
args = ', '.join(f'_{i}' for i in range(1, num_args+1))
print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
for i in range(1, num_args+1):
args = ', '.join(f'_{i}' for i in range(1, i+1))
cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@ -112,6 +135,8 @@ for i in range(1, num_args+1):
// Begin generated code
// clang-format off
#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
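
Aside on the generated counting macros (editorial note): `AT_NUM_ARGS` appends the descending literals 60..0 after the caller's arguments, and `AT_NUM_ARGS_AUX` returns its 61st parameter, which lands on the number of real arguments supplied. A tiny illustration with placeholder tokens that are never evaluated:

```cpp
#include <ATen/Dispatch_v2.h>

static_assert(AT_NUM_ARGS(one) == 1, "one argument");
static_assert(AT_NUM_ARGS(one, two, three) == 3, "three arguments");
```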

View File

@ -21,7 +21,6 @@
#if AT_CUDNN_ENABLED()
#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#endif
#if AT_MAGMA_ENABLED()
@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
#endif
}
long CUDAHooks::versionRuntimeCuDNN() const {
#if AT_CUDNN_ENABLED()
#ifndef USE_STATIC_CUDNN
return cudnnGetVersion();
#else
return CUDNN_VERSION;
#endif
#else
TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionCuDNNFrontend() const {
#if AT_CUDNN_ENABLED()
return CUDNN_FRONTEND_VERSION;
#else
TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
#endif
}
long CUDAHooks::versionMIOpen() const {
#if AT_ROCM_ENABLED()
return MIOPEN_VERSION_MAJOR * 10000 +

View File

@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCUDART() const override;
long versionCUDART() const override;
long versionCuDNN() const override;
long versionRuntimeCuDNN() const override;
long versionCuDNNFrontend() const override;
long versionMIOpen() const override;
std::string showConfig() const override;
double batchnormMinEpsilonCuDNN() const override;

View File

@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionRuntimeCuDNN() const {
TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionCuDNNFrontend() const {
TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
}
virtual long versionMIOpen() const {
TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
}

View File

@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
// we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
m_mutex.unlock();
auto stream = getDefaultMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
dispatch_sync(stream->queue(), ^() {
stream->synchronize(SyncType::COMMIT_AND_WAIT);
});
m_mutex.lock();

View File

@ -155,7 +155,4 @@ class TORCH_API MPSStreamImpl {
MPSStreamImpl();
};
#ifdef __OBJC__
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
#endif
} // namespace at::mps

View File

@ -9,6 +9,7 @@
@end
namespace at::mps {
//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------
@ -152,7 +153,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
if (length == 0) {
return;
}
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
endKernelCoalescing();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@ -182,7 +183,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
size_t dstOffset,
uint64_t profileId,
SyncType syncType) {
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
@autoreleasepool {
endKernelCoalescing();
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@ -235,7 +236,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
auto& profiler = getMPSProfiler();
const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
dispatch_sync_with_rethrow(_serialQueue, ^() {
dispatch_sync(_serialQueue, ^() {
endKernelCoalescing();
if (isGraphProfilingEnabled) {
// this function call is only relevant for interval-based Signposts
@ -288,19 +289,4 @@ MPSStream* getDefaultMPSStream() {
return MPSStreamImpl::getInstance();
}
// Helper methods
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
__block std::optional<std::exception_ptr> block_exception;
dispatch_sync(queue, ^() {
try {
block();
} catch (...) {
block_exception = std::current_exception();
}
});
if (block_exception) {
std::rethrow_exception(*block_exception);
}
}
} // namespace at::mps

View File

@ -409,7 +409,7 @@ struct ConvParams {
if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
return false;
}
static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
// broken on cuDNN 9.8 - 9.14
if (cudnn_version >= 90800 && cudnn_version < 91500) {
if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@ -453,7 +453,7 @@ struct ConvParams {
}
// native kernel doesn't support 64-bit non-splittable case
if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
// TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
if (cudnn_version < 0 || cudnn_version > 91000) {

View File

@ -884,69 +884,6 @@ struct type_specialized_kernel_launcher {
}
};
template <int arg_index>
struct type_specialized_broadcast_kernel_launcher {
template <
typename func_t,
typename array_t,
typename dtypes_t,
typename calc_t>
static void apply(
int64_t numel,
func_t f,
array_t data,
dtypes_t dtypes,
calc_t offset_calc) {
using traits = function_traits<func_t>;
using ret_t = typename traits::result_type;
using arg0_t = typename traits::template arg<0>::type;
using arg1_t = typename traits::template arg<1>::type;
if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
dtypes[1] == rt_binary_specializations[arg_index][1] &&
dtypes[2] == rt_binary_specializations[arg_index][2]) {
using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
constexpr int grp_sz = 128;
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
if (unrl) {
auto offsets0 = offset_calc.get(idx);
auto offsets1 = offset_calc.get(idx + grp_sz);
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
void* out0 = data[0] + offsets0[0];
void* out1 = data[0] + offsets1[0];
void* out2 = data[0] + offsets2[0];
void* out3 = data[0] + offsets3[0];
auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
*(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
*(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
*(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
*(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
} else {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
*(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
}
});
}
}
};
} // namespace
#endif
@ -1065,32 +1002,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
}
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
#ifdef USE_ROCM
if (check_binary_rt_types_for_specialization(iter)) {
// constexpr to reduce the amount of kernels generated for
// broadcast elementwise with mixed dtypes and limit which functors are actually
// applied to the load and store at compile time.
using func_tuple = typename traits::ArgsTuple;
if constexpr (
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
check_binary_functor_types_for_specialization<
func_tuple,
float,
float,
traits::arity,
/*arg_num=*/0>::check()) {
memory::detail::static_unroll<
type_specialized_broadcast_kernel_launcher,
rt_binary_specializations.size()>::with_args(
numel,
f,
data,
dtypes,
offset_calc
);
return;
}
}
constexpr int grp_sz = 128;
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
if (unrl) {

View File

@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
// supported in conv.
mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
if (groups > 1 && weight_zero_points.numel() > 1)
mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
dnnl::primitive_attr pattr;
bool src_need_zp = (act_zero_point != 0);
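
One editorial observation on the two variants of the mask line above: in C++ `^` is bitwise XOR, not exponentiation, so `(2 ^ 0) | (2 ^ 1)` equals the intended mask 3 only by coincidence, while `(1 << 0) | (1 << 1)` states "bit 0 (group) | bit 1 (output channel)" directly. A small standalone check, not part of the diff:

```cpp
#include <cstdio>

int main() {
  constexpr int xor_form   = (2 ^ 0) | (2 ^ 1);   // XOR: 2 | 3 == 3
  constexpr int shift_form = (1 << 0) | (1 << 1); // bits 0 and 1: 1 | 2 == 3
  static_assert(xor_form == shift_form, "same value, different clarity");
  std::printf("mask = %d\n", xor_form);
  return 0;
}
```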

View File

@ -40,6 +40,8 @@ using namespace at::mps;
namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
struct MPSScalar {
id<MTLBuffer> getMTLBuffer() const {
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());

View File

@ -53,6 +53,21 @@
@end
namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
__block std::optional<std::exception_ptr> block_exception;
dispatch_sync(queue, ^() {
try {
block();
} catch (...) {
block_exception = std::current_exception();
}
});
if (block_exception) {
std::rethrow_exception(*block_exception);
}
}
/**
* Computes distance from lowest to highest element offset in given tensor.
*/

View File

@ -141,9 +141,6 @@ static Tensor& addmv_out_mps_impl(const Tensor& self,
};
MPSStream* stream = at::mps::getCurrentMPSStream();
if (result.numel() == 0) {
return result;
}
Tensor matMulVec = at::mm(mat, vec.unsqueeze(1)).squeeze(1);
@autoreleasepool {

View File

@ -220,7 +220,7 @@ Tensor _embedding_bag_dense_backward_mps(const Tensor& output_grad,
auto num_threads = (params.mode == EmbeddingBagMode::MAX) ? output_grad.numel() : num_indices * params.feature_size;
MPSStream* stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_backward_{}_{}",
@ -273,7 +273,7 @@ Tensor _embedding_bag_per_sample_weights_backward_mps(const Tensor& output_grad,
auto num_threads = num_indices * feature_size;
MPSStream* stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = stream->commandEncoder();
auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("embedding_bag_per_sample_weights_backward_{}_{}",

View File

@ -923,7 +923,7 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_mps(const Tensor& input,
MPSStream* stream = getCurrentMPSStream();
TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS");
@autoreleasepool {
dispatch_sync_with_rethrow(stream->queue(), ^() {
mps::dispatch_sync_with_rethrow(stream->queue(), ^() {
// which kernel variant to use based on the normalized axis N size
const int N_READS = 4;
auto metalType = mps::scalarToMetalTypeString(input);

View File

@ -2803,7 +2803,7 @@
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA, MPS, MTIA: floor_divide_out
CPU, CUDA, MPS: floor_divide_out
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor

View File

@ -478,7 +478,7 @@ bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) {
const auto s_k = params.key.sym_size(2);
const auto d_qk = params.query.sym_size(3);
const auto d_v = params.value.sym_size(3);
long cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
long cudnn_version = at::detail::getCUDAHooks().versionCuDNN();
if (cudnn_version < 8903) {
if (debug) {
TORCH_WARN("SDPA fprop requires cudnn 8.9.3 or higher");
@ -709,7 +709,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) {
return false;
#endif
#if defined(CUDNN_VERSION)
static auto cudnn_version = at::detail::getCUDAHooks().versionRuntimeCuDNN();
static auto cudnn_version = cudnnGetVersion();
if (params.dropout > 0.0 && cudnn_version > 91100 && cudnn_version < 91400) {
if (debug) {
TORCH_WARN(CUDNN_VERSION, " cuDNN version does not support droppout in SDPA (9.11 - 9.13).");

View File

@ -1,8 +1,6 @@
#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/macros/Export.h>
#include <optional>
namespace c10 {
@ -17,8 +15,7 @@ struct C10_API AutogradState {
bool inference_mode,
bool fw_grad_mode,
bool multithreading_enabled)
: graph_exec_group_(std::nullopt),
grad_mode_(grad_mode),
: grad_mode_(grad_mode),
inference_mode_(inference_mode),
fw_grad_mode_(fw_grad_mode),
multithreading_enabled_(multithreading_enabled),
@ -44,10 +41,6 @@ struct C10_API AutogradState {
view_replay_enabled_ = view_replay_enabled;
}
void set_graph_exec_group(std::optional<SafePyObject> group) {
graph_exec_group_ = std::move(group);
}
bool get_grad_mode() const {
return grad_mode_;
}
@ -68,12 +61,7 @@ struct C10_API AutogradState {
return view_replay_enabled_;
}
const std::optional<SafePyObject>& get_graph_exec_group() const {
return graph_exec_group_;
}
private:
std::optional<SafePyObject> graph_exec_group_;
bool grad_mode_ : 1;
bool inference_mode_ : 1;
bool fw_grad_mode_ : 1;

View File

@ -66,15 +66,6 @@ def define_targets(rules):
],
)
rules.cc_test(
name = "util/nofatal_test",
srcs = ["util/nofatal_test.cpp"],
deps = [
"//c10/util:base",
"@com_google_googletest//:gtest_main",
],
)
rules.cc_test(
name = "util/ssize_test",
srcs = ["util/ssize_test.cpp"],

View File

@ -1,53 +0,0 @@
#include <gtest/gtest.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
namespace {
template <typename T>
inline void expectThrowsEq(T&& fn, const char* expected_msg) {
try {
std::forward<T>(fn)();
} catch (const c10::Error& e) {
EXPECT_TRUE(
std::string(e.what_without_backtrace()).find(expected_msg) !=
std::string::npos);
return;
}
ADD_FAILURE() << "Expected to throw exception with message \"" << expected_msg
<< "\" but didn't throw";
}
} // namespace
TEST(NofatalTest, TorchCheckComparisons) {
// quick make sure that no-op works as expected
TORCH_CHECK_EQ(1, 1) << "i am a silly message " << 1;
expectThrowsEq(
[]() { TORCH_CHECK_EQ(1, 2) << "i am a silly message " << 1; },
"Check failed: 1 == 2 (1 vs. 2). i am a silly message 1");
expectThrowsEq(
[]() { TORCH_CHECK_NE(2, 2); }, "Check failed: 2 != 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_LT(2, 2); }, "Check failed: 2 < 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_LE(3, 2); }, "Check failed: 3 <= 2 (3 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_GT(2, 2); }, "Check failed: 2 > 2 (2 vs. 2).");
expectThrowsEq(
[]() { TORCH_CHECK_GE(2, 3); }, "Check failed: 2 >= 3 (2 vs. 3).");
expectThrowsEq(
[]() {
void* p = nullptr;
TORCH_CHECK_NOTNULL(p);
},
"Check failed: 'p' must be non NULL.");
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
// if dbg build, DCHECK should result in death
EXPECT_DEATH(TORCH_DCHECK_EQ(1, 2), "Check failed");
#else
TORCH_DCHECK_EQ(1, 2); // no-op
#endif
#endif // GTEST_HAS_DEATH_TEST
}

View File

@ -702,98 +702,6 @@ namespace c10::detail {
#define TORCH_CHECK_ARG(cond, argN, ...) \
TORCH_CHECK(cond, "invalid argument ", argN, ": ", __VA_ARGS__)
#ifndef FATAL_IF
#ifdef C10_USE_GLOG
#define FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::google::GLOG_FATAL) \
.stream()
#else
#define FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL).stream()
#endif
#endif
#ifndef NON_FATAL_IF
#ifdef C10_USE_GLOG
#define NON_FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger( \
__FILE__, __LINE__, ::google::GLOG_FATAL, false) \
.stream()
#else
#define NON_FATAL_IF(condition) \
condition ? (void)0 \
: ::c10::LoggerVoidify() & \
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_FATAL, false) \
.stream()
#endif
#endif
// Binary comparison check macros
#define TORCH_CHECK_OP(val1, val2, op) \
NON_FATAL_IF(((val1)op(val2))) \
<< "Check failed: " #val1 " " #op " " #val2 " (" << (val1) << " vs. " \
<< (val2) << "). "
#define TORCH_DCHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << "). "
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
// Debug versions of TORCH_CHECK_OP macros
#ifndef NDEBUG
#define TORCH_DCHECK_EQ(val1, val2) TORCH_DCHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_DCHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_DCHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_DCHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_DCHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_DCHECK_OP(val1, val2, >)
#else // !NDEBUG
// Optimized versions - generate no code
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
TORCH_DCHECK_OP(val1, val2, >)
#endif // NDEBUG
// Null pointer check macro
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), false)
#ifndef NDEBUG
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull(__FILE__, __LINE__, #val, (val), true)
#else // !NDEBUG
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ----------------------------------------------------------------------------
// Deprecated macros
// ----------------------------------------------------------------------------

View File

@ -291,32 +291,6 @@ namespace c10 {
using fLB::FLAGS_logtostderr;
using fLI::FLAGS_minloglevel;
using fLI::FLAGS_v;
MessageLogger::MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal)
: stream_(), severity_(severity), exit_on_fatal_(exit_on_fatal) {}
MessageLogger::~MessageLogger() noexcept(false) {
if (severity_ == ::google::GLOG_FATAL) {
DealWithFatal();
}
}
std::stringstream& MessageLogger::stream() {
return stream_;
}
void MessageLogger::DealWithFatal() {
if (exit_on_fatal_) {
LOG(FATAL) << stream_.str();
} else {
throw c10::Error(stream_.str(), nullptr, nullptr);
}
}
} // namespace c10
C10_DEFINE_int(
@ -438,16 +412,17 @@ void ShowLogInfoToStderr() {
FLAGS_caffe2_log_level = GLOG_INFO;
}
MessageLogger::MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal)
: severity_(severity), exit_on_fatal_(exit_on_fatal) {
MessageLogger::MessageLogger(const char* file, int line, int severity)
: severity_(severity) {
if (severity_ < FLAGS_caffe2_log_level) {
// Nothing needs to be logged.
return;
}
#ifdef ANDROID
tag_ = "native";
#else // !ANDROID
tag_ = "";
#endif // ANDROID
time_t rawtime = 0;
time(&rawtime);
@ -483,7 +458,7 @@ MessageLogger::MessageLogger(
}
// Output the contents of the stream to the proper channel on destruction.
MessageLogger::~MessageLogger() noexcept(false) {
MessageLogger::~MessageLogger() {
if (severity_ < FLAGS_caffe2_log_level) {
// Nothing needs to be logged.
return;
@ -523,18 +498,6 @@ MessageLogger::~MessageLogger() noexcept(false) {
}
}
std::stringstream& MessageLogger::stream() {
return stream_;
}
void MessageLogger::DealWithFatal() {
if (exit_on_fatal_) {
abort();
} else {
throw c10::Error(stream_.str(), nullptr, nullptr);
}
}
} // namespace c10
#endif // !C10_USE_GLOG

View File

@ -1,74 +0,0 @@
#ifndef C10_UTIL_LOGGING_COMMON_H_
#define C10_UTIL_LOGGING_COMMON_H_
#include <c10/macros/Export.h>
#include <sstream>
namespace c10 {
// MessageLogger that throws exceptions instead of aborting (glog version)
// or logs and may abort (non-glog version).
class C10_API MessageLogger {
public:
MessageLogger(
const char* file,
int line,
int severity,
bool exit_on_fatal = true);
~MessageLogger() noexcept(false);
// Return the stream associated with the logger object.
std::stringstream& stream();
private:
// When there is a fatal log, and fatal == true, we abort
// otherwise, we throw.
void DealWithFatal();
#if defined(ANDROID) && !defined(C10_USE_GLOG)
const char* tag_{"native"};
#endif
std::stringstream stream_;
int severity_;
bool exit_on_fatal_;
};
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class C10_API LoggerVoidify {
public:
LoggerVoidify() = default;
// This has to be an operator with a precedence lower than << but
// higher than ?:
void operator&(const std::ostream& s [[maybe_unused]]) {}
};
// Forward declarations for CheckNotNull functions
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal = true);
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal = true);
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal = true);
} // namespace c10
#endif // C10_UTIL_LOGGING_COMMON_H_

View File

@ -47,53 +47,57 @@ INSTANTIATE_FOR_CONTAINER(set)
#endif
#include <c10/util/logging_common.h>
#include <glog/logging.h>
namespace c10 {
// Additional macros on top of glog
#define TORCH_CHECK_EQ(val1, val2) CHECK_EQ(val1, val2)
#define TORCH_CHECK_NE(val1, val2) CHECK_NE(val1, val2)
#define TORCH_CHECK_LE(val1, val2) CHECK_LE(val1, val2)
#define TORCH_CHECK_LT(val1, val2) CHECK_LT(val1, val2)
#define TORCH_CHECK_GE(val1, val2) CHECK_GE(val1, val2)
#define TORCH_CHECK_GT(val1, val2) CHECK_GT(val1, val2)
[[noreturn]] void ThrowEnforceNotMet(
const char* file,
const int line,
const char* condition,
const std::string& msg,
const void* caller);
#ifndef NDEBUG
#define TORCH_DCHECK_EQ(val1, val2) DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) DCHECK_GT(val1, val2)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
DCHECK_EQ(val1, val2)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
DCHECK_NE(val1, val2)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
DCHECK_LE(val1, val2)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
DCHECK_LT(val1, val2)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
DCHECK_GE(val1, val2)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
DCHECK_GT(val1, val2)
#endif // NDEBUG
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
if (t == nullptr) {
MessageLogger(file, line, ::google::GLOG_FATAL, fatal).stream()
<< "Check failed: '" << names << "' must be non NULL. ";
}
return t;
}
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) CHECK_NOTNULL(val)
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
}
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
}
} // namespace c10
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) DCHECK_NOTNULL(val)
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
DCHECK_NOTNULL(val)
#endif // NDEBUG
// Log with source location information override (to be used in generic
// warning/error handlers implemented as functions, not macros)

View File

@ -13,7 +13,6 @@
#include <vector>
#include <c10/util/Flags.h>
#include <c10/util/logging_common.h>
const char CAFFE2_SEVERITY_PREFIX[] = "FEWIV";
@ -25,40 +24,61 @@ const int GLOG_ERROR = 2;
const int GLOG_WARNING = 1;
const int GLOG_INFO = 0;
class C10_API MessageLogger {
public:
MessageLogger(const char* file, int line, int severity);
~MessageLogger();
// Return the stream associated with the logger object.
std::stringstream& stream() {
return stream_;
}
private:
// When there is a fatal log, we simply abort.
void DealWithFatal() {
abort();
}
const char* tag_;
std::stringstream stream_;
int severity_;
};
// This class is used to explicitly ignore values in the conditional
// logging macros. This avoids compiler warnings like "value computed
// is not used" and "statement has no effect".
class C10_API LoggerVoidify {
public:
LoggerVoidify() = default;
// This has to be an operator with a precedence lower than << but
// higher than ?:
void operator&(const std::ostream& s [[maybe_unused]]) {}
};
// Log a message and terminate.
template <class T>
void LogMessageFatal(const char* file, int line, const T& message) {
MessageLogger(file, line, GLOG_FATAL).stream() << message;
}
// Helpers for TORCH_CHECK_NOTNULL(). Two are necessary to support both raw
// pointers and smart pointers.
template <typename T>
T& CheckNotNullCommon(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
T& CheckNotNullCommon(const char* file, int line, const char* names, T& t) {
if (t == nullptr) {
MessageLogger(file, line, GLOG_FATAL, fatal).stream()
<< "Check failed: '" << names << "' must be non NULL. ";
LogMessageFatal(file, line, std::string(names));
}
return t;
}
template <typename T>
T* CheckNotNull(
const char* file,
int line,
const char* names,
T* t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
T* CheckNotNull(const char* file, int line, const char* names, T* t) {
return CheckNotNullCommon(file, line, names, t);
}
template <typename T>
T& CheckNotNull(
const char* file,
int line,
const char* names,
T& t,
bool fatal) {
return CheckNotNullCommon(file, line, names, t, fatal);
T& CheckNotNull(const char* file, int line, const char* names, T& t) {
return CheckNotNullCommon(file, line, names, t);
}
} // namespace c10
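As the comment above notes, two overloads are needed so the same null check covers both raw pointers and smart pointers. A standalone sketch of that shape (hypothetical `EnsureNotNull*` names, not the real c10 signatures):

#include <cstdlib>
#include <iostream>
#include <memory>

// Shared implementation: works for anything comparable against nullptr.
template <typename T>
T& EnsureNotNullCommon(const char* expr, T& t) {
  if (t == nullptr) {
    std::cerr << "Check failed: '" << expr << "' must be non-NULL\n";
    std::abort();
  }
  return t;
}

// Raw pointers come in (and go back out) by value...
template <typename T>
T* EnsureNotNull(const char* expr, T* t) {
  return EnsureNotNullCommon(expr, t);
}

// ...while smart pointers are taken by reference so they are not copied
// (unique_ptr is not even copyable).
template <typename T>
T& EnsureNotNull(const char* expr, T& t) {
  return EnsureNotNullCommon(expr, t);
}

int main() {
  auto owned = std::make_unique<int>(7);
  int* raw = owned.get();
  EnsureNotNull("owned", owned);  // smart-pointer overload
  EnsureNotNull("raw", raw);      // raw-pointer overload
  std::cout << *raw << '\n';
  return 0;
}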
@ -116,6 +136,65 @@ static_assert(
::c10::MessageLogger(__FILE__, __LINE__, ::c10::GLOG_##n).stream()
#endif // NDEBUG
#define TORCH_CHECK_OP(val1, val2, op) \
FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
<< (val1) << " vs. " << (val2) << ") "
// TORCH_CHECK_OP macro definitions
#define TORCH_CHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_CHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_CHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_CHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_CHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_CHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
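The stringization in `TORCH_CHECK_OP` above (`#val1 " " #op " " #val2` plus the streamed values) is what turns a failed comparison into a message like `Check failed: x == y (3 vs. 4)`. A standalone sketch of just that formatting idea (a hypothetical `MY_CHECK_OP`, without the real FATAL_IF plumbing):

#include <cstdlib>
#include <iostream>

// On failure, print the stringized expression and the runtime values,
// then abort; on success, do nothing.
#define MY_CHECK_OP(val1, val2, op)                              \
  do {                                                           \
    if (!((val1)op(val2))) {                                     \
      std::cerr << "Check failed: " #val1 " " #op " " #val2 " (" \
                << (val1) << " vs. " << (val2) << ")\n";         \
      std::abort();                                              \
    }                                                            \
  } while (false)

int main() {
  int x = 3, y = 4;
  MY_CHECK_OP(x, x, ==);  // passes silently
  MY_CHECK_OP(x, y, ==);  // prints "Check failed: x == y (3 vs. 4)" and aborts
  return 0;
}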
#ifndef NDEBUG
// Debug only versions of TORCH_CHECK_OP macros.
#define TORCH_DCHECK_EQ(val1, val2) TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) TORCH_CHECK_OP(val1, val2, >)
#else // !NDEBUG
// These versions generate no code in optimized mode.
#define TORCH_DCHECK_EQ(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, ==)
#define TORCH_DCHECK_NE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, !=)
#define TORCH_DCHECK_LE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <=)
#define TORCH_DCHECK_LT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, <)
#define TORCH_DCHECK_GE(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >=)
#define TORCH_DCHECK_GT(val1, val2) \
while (false) \
TORCH_CHECK_OP(val1, val2, >)
#endif // NDEBUG
// Check that a pointer is not null.
#define TORCH_CHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#ifndef NDEBUG
// Debug only version of TORCH_CHECK_NOTNULL
#define TORCH_DCHECK_NOTNULL(val) \
::c10::CheckNotNull( \
__FILE__, __LINE__, "Check failed: '" #val "' Must be non NULL", (val))
#else // !NDEBUG
// Optimized version - generates no code.
#define TORCH_DCHECK_NOTNULL(val) \
while (false) \
TORCH_CHECK_NOTNULL(val)
#endif // NDEBUG
// ---------------------- Support for std objects --------------------------
// These are adapted from glog to support a limited set of logging capability
// for STL objects.

View File

@ -1941,7 +1941,6 @@ if(BUILD_TEST)
foreach(test_src ${Caffe2_XPU_TEST_SRCS})
get_filename_component(test_name ${test_src} NAME_WE)
add_executable(${test_name} "${test_src}")
torch_compile_options(${test_name})
target_link_libraries(${test_name} torch_library gtest_main)
target_include_directories(${test_name} PRIVATE $<INSTALL_INTERFACE:include>)
target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE})

View File

@ -172,9 +172,9 @@ ignore = [
"SIM102", "SIM103", "SIM112", # flake8-simplify code styles
"SIM105", # these ignores are from flake8-simplify. please fix or ignore with commented reason
"SIM108", # SIM108 ignored because we prefer if-else-block instead of ternary expression
"SIM110", # Checks for for loops that can be replaced with a builtin function, like any or all.
"SIM110",
"SIM114", # Combine `if` branches using logical `or` operator
"SIM115", # Checks for cases where files are opened without using a context manager.
"SIM115",
"SIM116", # Disable Use a dictionary instead of consecutive `if` statements
"SIM117",
"SIM118",

View File

@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None:
raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`")
def mirror_inductor_external_kernels() -> None:
"""
Copy external kernels into Inductor so they are importable.
"""
paths = [
(
CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py",
CWD
/ "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py",
),
]
for new_path, orig_path in paths:
# Create the dirs involved in new_path if they don't exist
if not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the files from the orig location to the new location
if orig_path.is_file():
shutil.copyfile(orig_path, new_path)
continue
if orig_path.is_dir():
if new_path.exists():
# copytree fails if the tree exists already, so remove it.
shutil.rmtree(new_path)
shutil.copytree(orig_path, new_path)
continue
raise RuntimeError(
"Check the file paths in `mirror_inductor_external_kernels()`"
)
# ATTENTION: THIS IS AI SLOP
def extract_variant_from_version(version: str) -> str:
"""Extract variant from version string, defaulting to 'cpu'."""
@ -1616,6 +1647,8 @@ def main() -> None:
if RUN_BUILD_DEPS:
build_deps()
mirror_inductor_external_kernels()
(
ext_modules,
cmdclass,
@ -1649,6 +1682,7 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_inductor/kernel/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -208,7 +208,7 @@ class _BaseDataSparsiferTestCase(TestCase):
assert len(sparsifier1.data_groups) == len(sparsifier2.data_groups)
state1 = state_dict1["state"]
for name in state1:
for name in state1.keys():
# compare mask
assert name in sparsifier2.state
assert "mask" in sparsifier2.state[name]

View File

@ -119,7 +119,7 @@ class TestBaseSparsifier(TestCase):
for idx in range(len(sparsifier0.groups)):
mg0 = sparsifier0.groups[idx]
mg1 = sparsifier1.groups[idx]
for key in mg0:
for key in mg0.keys():
assert key in mg1
if key == "module":
# We cannot compare modules as they are different

View File

@ -10,8 +10,6 @@ set(AOTI_ABI_CHECK_TEST_SRCS
${AOTI_ABI_CHECK_TEST_ROOT}/main.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_cast.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dispatch_v2.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp
${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp

View File

@ -1,82 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/Dispatch.h>
#include <torch/headeronly/core/Dispatch_v2.h>
// MY_PRIVATE_CHECK_SELECTIVE_BUILD is a prelude to case block. For
// testing, we do nothing:
#define MY_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) /* empty */
#define MY_PRIVATE_CASE_TYPE_USING_HINT(...) \
THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
MY_PRIVATE_CHECK_SELECTIVE_BUILD, __VA_ARGS__)
#define MY_DISPATCH_CASE(...) \
THO_DISPATCH_CASE_TMPL(MY_PRIVATE_CASE_TYPE_USING_HINT, __VA_ARGS__)
// MY_RECORD_KERNEL_FUNCTION_DTYPE is a prelude to switch
// statement. For testing, we just avoid unused variable warning:
#define MY_RECORD_KERNEL_FUNCTION_DTYPE(DISPATCHNAME, ENUMTYPE) \
(void)DISPATCHNAME
// MY_CHECK_NOT_IMPLEMENTED is called in switch default block. For
// testing, we count case mismatches:
#define MY_CHECK_NOT_IMPLEMENTED(...) default_count++
#define MY_DISPATCH_SWITCH(...) \
THO_DISPATCH_SWITCH_TMPL( \
MY_RECORD_KERNEL_FUNCTION_DTYPE, MY_CHECK_NOT_IMPLEMENTED, __VA_ARGS__)
// MY_CASE_FUNCTION is called in a case block. For testing, we count
// case matches and ensure that scalar_t/index_t type is defined:
#define MY_CASE_FUNCTION \
[&] { \
count++; \
scalar_t tmp; \
(void)tmp; \
}
#define MY_INDEX_CASE_FUNCTION \
[&] { \
count++; \
index_t tmp; \
(void)tmp; \
}
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
#define MY_DISPATCH_V2(TYPE, NAME, BODY, ...) \
THO_DISPATCH_V2_TMPL( \
MY_DISPATCH_SWITCH, \
MY_DISPATCH_CASE, \
TYPE, \
NAME, \
AT_WRAP(BODY), \
__VA_ARGS__)
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
TEST(TestDispatchV2, NAME) { \
using torch::headeronly::ScalarType; \
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
int8_t total_count = 0; \
int8_t count = 0; \
int8_t default_count = 0; \
for (ScalarType t : \
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
total_count++; \
MY_DISPATCH_V2(t, "test_my_dispatch_v2", MY_CASE_FUNCTION, __VA_ARGS__); \
} \
EXPECT_EQ(count, EXPECTEDCOUNT); \
EXPECT_EQ(default_count + count, total_count); \
}
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
#undef DEFINE_ITEM
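The deleted test above exercises the general dtype-dispatch shape: a runtime scalar-type value is switched over, and each case runs a caller-supplied lambda with a local type alias (`scalar_t`) bound to the matching C++ type. A standalone sketch of that pattern with made-up names (not the THO_DISPATCH macros themselves):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class DType { Float, Int64 };

// Runs `body` with `scalar_t` aliased to the C++ type behind `dt`.
// Because this is a macro, the lambda text is pasted inside each case,
// where the local `using scalar_t = ...;` is in scope.
#define MY_DISPATCH(dt, body)                        \
  [&] {                                              \
    switch (dt) {                                    \
      case DType::Float: {                           \
        using scalar_t = float;                      \
        return body();                               \
      }                                              \
      case DType::Int64: {                           \
        using scalar_t = int64_t;                    \
        return body();                               \
      }                                              \
      default:                                       \
        throw std::runtime_error("unhandled dtype"); \
    }                                                \
  }()

int main() {
  DType dt = DType::Float;
  std::size_t size = MY_DISPATCH(dt, [&] { return sizeof(scalar_t); });
  std::cout << size << '\n';  // 4 on typical platforms
  return 0;
}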

View File

@ -1,45 +0,0 @@
#include <gtest/gtest.h>
#include <torch/headeronly/core/Dispatch_v2.h>
#include <torch/headeronly/util/Exception.h>
#define DEFINE_ITEM(TYPE, SCALARTYPE) ScalarType::SCALARTYPE,
#define TEST_DISPATCH_V2(NAME, EXPECTEDCOUNT, ...) \
TEST(TestThoDispatchV2, NAME) { \
using torch::headeronly::ScalarType; \
using torch::headeronly::impl::ScalarTypeToCPPTypeT; \
int8_t total_count = 0; \
int8_t count = 0; \
int8_t default_count = 0; \
for (ScalarType t : \
{AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ITEM)}) { \
total_count++; \
try { \
THO_DISPATCH_V2( \
t, \
"test_tho_dispatch_v2", \
[&] { \
count++; \
scalar_t tmp; \
(void)tmp; \
}, \
__VA_ARGS__); \
} catch (...) { \
default_count++; /* counts mismatches */ \
} \
} \
EXPECT_EQ(count, EXPECTEDCOUNT); \
EXPECT_EQ(default_count + count, total_count); \
}
TEST_DISPATCH_V2(AT_FLOAT8_TYPES_, 5, AT_FLOAT8_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_, 5, AT_INTEGRAL_TYPES);
TEST_DISPATCH_V2(AT_FLOATING_TYPES_, 2, AT_FLOATING_TYPES);
TEST_DISPATCH_V2(AT_BAREBONES_UNSIGNED_TYPES_, 3, AT_BAREBONES_UNSIGNED_TYPES);
TEST_DISPATCH_V2(AT_INTEGRAL_TYPES_V2_, 8, AT_INTEGRAL_TYPES_V2);
TEST_DISPATCH_V2(AT_COMPLEX_TYPES_, 2, AT_COMPLEX_TYPES);
TEST_DISPATCH_V2(AT_QINT_TYPES_, 3, AT_QINT_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_, 7, AT_ALL_TYPES);
TEST_DISPATCH_V2(AT_ALL_TYPES_AND_COMPLEX_, 9, AT_ALL_TYPES_AND_COMPLEX);
#undef DEFINE_ITEM

View File

@ -67,13 +67,13 @@ Tensor sgd_out_of_place(
void boxed_sgd_out_of_place(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = sgd_out_of_place(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<Tensor>(stack[1]),
float(torch::stable::detail::to<double>(stack[2])),
torch::stable::detail::to<double>(stack[3]),
torch::stable::detail::to<bool>(stack[4]));
to<Tensor>(stack[0]),
to<Tensor>(stack[1]),
float(to<double>(stack[2])),
to<double>(stack[3]),
to<bool>(stack[4]));
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY(libtorch_agnostic, m) {
@ -89,8 +89,8 @@ Tensor identity(Tensor t) {
}
void boxed_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = identity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = identity(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -108,14 +108,14 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CPU, m) {
Tensor my_abs(Tensor t) {
const auto num_args = 1;
StableIValue stack[num_args];
stack[0] = torch::stable::detail::from(t);
stack[0] = from(t);
aoti_torch_call_dispatcher("aten::abs", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_my_abs(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_abs(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_abs(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -132,21 +132,21 @@ Tensor my_ones_like(Tensor t, StableIValue device) {
auto mf = aoti_torch_memory_format_contiguous_format();
stack[0] = torch::stable::detail::from(t);
stack[1] = torch::stable::detail::from(std::optional(t.scalar_type())); // dtype
stack[2] = torch::stable::detail::from(std::nullopt); // layout
stack[3] = torch::stable::detail::from(std::optional(device)); // device
stack[4] = torch::stable::detail::from(std::optional(false)); // pin_memory
stack[5] = torch::stable::detail::from(std::optional(mf)); // memory_format
stack[0] = from(t);
stack[1] = from(std::optional(t.scalar_type())); // dtype
stack[2] = from(std::nullopt); // layout
stack[3] = from(std::optional(device)); // device
stack[4] = from(std::optional(false)); // pin_memory
stack[5] = from(std::optional(mf)); // memory_format
aoti_torch_call_dispatcher("aten::ones_like", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_my_ones_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = my_ones_like(torch::stable::detail::to<Tensor>(stack[0]), stack[1]);
stack[0] = torch::stable::detail::from(res);
Tensor res = my_ones_like(to<Tensor>(stack[0]), stack[1]);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -159,28 +159,28 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
std::tuple<Tensor, Tensor, bool> exp_neg_is_leaf(Tensor t1, Tensor t2, Tensor t3) {
StableIValue stack_exp[1];
stack_exp[0] = torch::stable::detail::from(t1);
stack_exp[0] = from(t1);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
StableIValue stack_neg[1];
stack_neg[0] = torch::stable::detail::from(t2);
stack_neg[0] = from(t2);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
StableIValue stack_is_leaf[1];
stack_is_leaf[0] = torch::stable::detail::from(t3);
stack_is_leaf[0] = from(t3);
aoti_torch_call_dispatcher("aten::is_leaf", "", stack_is_leaf);
return std::make_tuple(
torch::stable::detail::to<Tensor>(stack_exp[0]),
torch::stable::detail::to<Tensor>(stack_neg[0]),
torch::stable::detail::to<bool>(stack_is_leaf[0]));
to<Tensor>(stack_exp[0]),
to<Tensor>(stack_neg[0]),
to<bool>(stack_is_leaf[0]));
}
void boxed_exp_neg_is_leaf(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto tuple = exp_neg_is_leaf(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<Tensor>(stack[2]));
stack[0] = torch::stable::detail::from(std::get<0>(tuple));
stack[1] = torch::stable::detail::from(std::get<1>(tuple));
stack[2] = torch::stable::detail::from(std::get<2>(tuple));
auto tuple = exp_neg_is_leaf(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<Tensor>(stack[2]));
stack[0] = from(std::get<0>(tuple));
stack[1] = from(std::get<1>(tuple));
stack[2] = from(std::get<2>(tuple));
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -193,15 +193,15 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor neg_exp(Tensor t) {
StableIValue stack[1];
stack[0] = torch::stable::detail::from(t);
stack[0] = from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack);
aoti_torch_call_dispatcher("aten::neg", "", stack);
return torch::stable::detail::to<Tensor>(stack[0]);
return to<Tensor>(stack[0]);
}
void boxed_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -214,10 +214,10 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
Tensor divide_neg_exp(Tensor t) {
StableIValue stack_neg[1];
stack_neg[0] = torch::stable::detail::from(t);
stack_neg[0] = from(t);
StableIValue stack_exp[1];
stack_exp[0] = torch::stable::detail::from(t);
stack_exp[0] = from(t);
aoti_torch_call_dispatcher("aten::exp", "", stack_exp);
aoti_torch_call_dispatcher("aten::neg", "", stack_neg);
@ -225,12 +225,12 @@ Tensor divide_neg_exp(Tensor t) {
stack_div[0] = stack_neg[0];
stack_div[1] = stack_exp[0];
aoti_torch_call_dispatcher("aten::divide", "Tensor", stack_div);
return torch::stable::detail::to<Tensor>(stack_div[0]);
return to<Tensor>(stack_div[0]);
}
void boxed_divide_neg_exp(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor res = divide_neg_exp(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
Tensor res = divide_neg_exp(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -246,8 +246,8 @@ bool is_contiguous(Tensor t) {
}
void boxed_is_contiguous(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
bool res = is_contiguous(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
bool res = is_contiguous(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -263,9 +263,9 @@ Tensor my_transpose(Tensor t, int64_t dim0, int64_t dim1) {
}
void boxed_my_transpose(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_transpose(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<int64_t>(stack[1]), torch::stable::detail::to<int64_t>(stack[2]));
auto res = my_transpose(to<Tensor>(stack[0]), to<int64_t>(stack[1]), to<int64_t>(stack[2]));
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
Tensor my_empty_like(Tensor t) {
@ -273,8 +273,8 @@ Tensor my_empty_like(Tensor t) {
}
void boxed_empty_like(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_empty_like(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_empty_like(to<Tensor>(stack[0]));
stack[0] = from(res);
}
bool my_is_cpu(Tensor t) {
@ -283,8 +283,8 @@ bool my_is_cpu(Tensor t) {
void boxed_my_is_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_is_cpu(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_is_cpu(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor fill_infinity(Tensor t) {
@ -296,8 +296,8 @@ void boxed_fill_infinity(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = fill_infinity(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = fill_infinity(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_pad(Tensor t) {
@ -310,8 +310,8 @@ void boxed_my_pad(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_pad(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_pad(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_narrow(Tensor t, int64_t dim, int64_t start, int64_t length) {
@ -323,11 +323,11 @@ void boxed_my_narrow(
uint64_t num_args,
uint64_t num_outputs) {
auto res = my_narrow(
torch::stable::detail::to<Tensor>(stack[0]),
torch::stable::detail::to<int64_t>(stack[1]),
torch::stable::detail::to<int64_t>(stack[2]),
torch::stable::detail::to<int64_t>(stack[3]));
stack[0] = torch::stable::detail::from(res);
to<Tensor>(stack[0]),
to<int64_t>(stack[1]),
to<int64_t>(stack[2]),
to<int64_t>(stack[3]));
stack[0] = from(res);
}
Tensor my_new_empty_dtype_variant(Tensor t) {
@ -342,8 +342,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) {
}
void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_empty_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_new_empty_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_new_zeros_dtype_variant(Tensor t) {
@ -352,8 +352,8 @@ Tensor my_new_zeros_dtype_variant(Tensor t) {
}
void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_new_zeros_dtype_variant(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
@ -361,8 +361,8 @@ Tensor my_copy_(Tensor dst, Tensor src, bool non_blocking) {
}
void boxed_my_copy_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_copy_(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]), torch::stable::detail::to<bool>(stack[2]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_copy_(to<Tensor>(stack[0]), to<Tensor>(stack[1]), to<bool>(stack[2]));
stack[0] = from(tensor_res);
}
Tensor my_clone(Tensor t) {
@ -370,8 +370,8 @@ Tensor my_clone(Tensor t) {
}
void boxed_my_clone(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
Tensor tensor_res = my_clone(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(tensor_res);
Tensor tensor_res = my_clone(to<Tensor>(stack[0]));
stack[0] = from(tensor_res);
}
@ -408,8 +408,8 @@ Tensor my_zero_(Tensor t) {
}
void boxed_my_zero_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_zero_(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_zero_(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_amax(Tensor t) {
@ -417,8 +417,8 @@ Tensor my_amax(Tensor t) {
}
void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_amax(to<Tensor>(stack[0]));
stack[0] = from(res);
}
Tensor my_amax_vec(Tensor t) {
@ -426,8 +426,8 @@ Tensor my_amax_vec(Tensor t) {
}
void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = my_amax_vec(torch::stable::detail::to<Tensor>(stack[0]));
stack[0] = torch::stable::detail::from(res);
auto res = my_amax_vec(to<Tensor>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -464,8 +464,8 @@ void boxed_test_default_constructor(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
bool res = test_default_constructor(torch::stable::detail::to<bool>(stack[0]));
stack[0] = torch::stable::detail::from(res);
bool res = test_default_constructor(to<bool>(stack[0]));
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -478,56 +478,6 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my_amax_vec", &boxed_my_amax_vec);
}
std::vector<Tensor> my__foreach_mul(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul", "List", stack.data());
return torch::stable::detail::to<std::vector<Tensor>>(stack[0]);
}
void boxed_my__foreach_mul(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
// Why is the following NOT torch::stable::detail::to<HeaderOnlyArrayRef<Tensor>>(stack[0])? Because calling `to`
// on a StableIValue means that the result is owning its underlying data now! HeaderOnlyArrayRef
// is not owning, so it cannot safely steward the result of the torch::stable::detail::to<>.
auto res = my__foreach_mul(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
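The comment above is purely about ownership: converting out of a StableIValue produces an owning container, and a non-owning view type cannot keep that result alive. The same hazard exists with plain standard types; a sketch with no Torch APIs involved:

#include <iostream>
#include <string>
#include <string_view>
#include <vector>

// Stand-in for the "convert out of the boxed value" step: the caller
// receives a freshly owned vector.
std::vector<std::string> decode_names() {
  return {"alpha", "beta"};
}

int main() {
  // BAD (left commented out): a non-owning view into a temporary; the
  // vector and its strings die at the end of the full expression.
  // std::string_view dangling = decode_names()[0];

  // GOOD: keep the owning result alive, then hand out non-owning views.
  std::vector<std::string> owned = decode_names();
  std::string_view first = owned[0];  // safe: 'owned' outlives 'first'
  std::cout << first << '\n';
  return 0;
}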
void my__foreach_mul_(torch::headeronly::HeaderOnlyArrayRef<Tensor> self, torch::headeronly::HeaderOnlyArrayRef<Tensor> other) {
std::array<StableIValue, 2> stack = {torch::stable::detail::from(self), torch::stable::detail::from(other)};
aoti_torch_call_dispatcher("aten::_foreach_mul_", "List", stack.data());
}
void boxed_my__foreach_mul_(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
my__foreach_mul_(torch::stable::detail::to<std::vector<Tensor>>(stack[0]), torch::stable::detail::to<std::vector<Tensor>>(stack[1]));
}
std::vector<Tensor> make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) {
// This function tests that my__foreach_mul can take in std::initializer_lists
// in addition to std::vectors.
Tensor t1_1 = my_clone(t1);
Tensor t1_2 = my_clone(t1);
Tensor t2_1 = my_clone(t2);
Tensor t2_2 = my_clone(t2);
return my__foreach_mul({t1_1, t2_1}, {t1_2, t2_2});
}
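The comment in the function above relies on the calling convention of a lightweight array-ref parameter: it can be constructed implicitly from either a `std::vector` or a braced initializer list. A generic sketch of that convention (a hypothetical `Span`, not the real HeaderOnlyArrayRef):

#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

// Hypothetical non-owning view over contiguous elements.
template <typename T>
class Span {
 public:
  /* implicit */ Span(const std::vector<T>& v)
      : data_(v.data()), size_(v.size()) {}
  /* implicit */ Span(std::initializer_list<T> il)
      : data_(il.begin()), size_(il.size()) {}
  std::size_t size() const { return size_; }
  const T& operator[](std::size_t i) const { return data_[i]; }

 private:
  const T* data_;
  std::size_t size_;
};

double sum(Span<double> xs) {
  double total = 0;
  for (std::size_t i = 0; i < xs.size(); ++i) total += xs[i];
  return total;
}

int main() {
  std::vector<double> v{1.0, 2.0};
  std::cout << sum(v) << ' '             // accepts a vector
            << sum({3.0, 4.0}) << '\n';  // ...and a braced list
  return 0;
}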
void boxed_make_tensor_clones_and_call_foreach(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
auto res = make_tensor_clones_and_call_foreach(torch::stable::detail::to<Tensor>(stack[0]), torch::stable::detail::to<Tensor>(stack[1]));
stack[0] = torch::stable::detail::from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
m.def("my__foreach_mul(Tensor[] self, Tensor[] other) -> Tensor[]");
m.def("my__foreach_mul_(Tensor(a!)[] self, Tensor[] other) -> ()");
m.def("make_tensor_clones_and_call_foreach(Tensor t1, Tensor t2) -> Tensor[]");
}
STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("my__foreach_mul", &boxed_my__foreach_mul);
m.impl("my__foreach_mul_", &boxed_my__foreach_mul_);
m.impl("make_tensor_clones_and_call_foreach", &boxed_make_tensor_clones_and_call_foreach);
}
// Test functions for torch::stable::accelerator APIs
#ifdef LAE_USE_CUDA
@ -550,8 +500,8 @@ void boxed_test_device_guard(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int res = test_device_guard(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
int res = test_device_guard(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
}
int64_t test_device_guard_set_index() {
@ -570,7 +520,7 @@ void boxed_test_device_guard_set_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_device_guard_set_index();
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
int64_t test_stream(int32_t device_index) {
@ -586,8 +536,8 @@ void boxed_test_stream(
StableIValue* stack,
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_stream(static_cast<int64_t>(torch::stable::detail::to<int64_t>(stack[0])));
stack[0] = torch::stable::detail::from(res);
int64_t res = test_stream(static_cast<int64_t>(to<int64_t>(stack[0])));
stack[0] = from(res);
}
int64_t test_get_current_device_index() {
@ -599,7 +549,7 @@ void boxed_test_get_current_device_index(
uint64_t num_args,
uint64_t num_outputs) {
int64_t res = test_get_current_device_index();
stack[0] = torch::stable::detail::from(res);
stack[0] = from(res);
}
STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
@ -615,5 +565,4 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
m.impl("test_stream", &boxed_test_stream);
m.impl("test_get_current_device_index", &boxed_test_get_current_device_index);
}
#endif // LAE_USE_CUDA

View File

@ -333,45 +333,3 @@ def my_new_zeros_dtype_variant(t) -> Tensor:
Returns: New zeros tensor
"""
return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
def my__foreach_mul_(tensors, others) -> ():
"""
Updates tensors to be the result of pointwise multiplying with others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: nothing, tensors is updated in place.
"""
torch.ops.libtorch_agnostic.my__foreach_mul_.default(tensors, others)
def my__foreach_mul(tensors, others) -> list[Tensor]:
"""
Returns a list of tensors that are the results of pointwise multiplying
tensors and others.
Args:
tensors: list of tensors
others: list of tensors (with the same corresponding shapes as tensors)
Returns: list of multiplied tensors
"""
return torch.ops.libtorch_agnostic.my__foreach_mul.default(tensors, others)
def make_tensor_clones_and_call_foreach(t1, t2) -> list[Tensor]:
"""
Returns a list of 2 tensors corresponding to the square of the inputs.
Args:
t1: Tensor
t2: Tensor
Returns: list of [t1^2, t2^2]
"""
return torch.ops.libtorch_agnostic.make_tensor_clones_and_call_foreach.default(
t1, t2
)

View File

@ -367,57 +367,6 @@ if not IS_WINDOWS:
self.assertNotEqual(result.data_ptr(), expected.data_ptr())
self.assertEqual(result.stride(), expected.stride())
def test_my__foreach_mul_(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
tensors_c = [t.clone() for t in tensors]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
libtorch_agnostic.ops.my__foreach_mul_(tensors, others)
expected_values = torch._foreach_mul(tensors_c, others)
for tensor_t, expected_t in zip(tensors, expected_values):
self.assertEqual(tensor_t, expected_t)
def test_my__foreach_mul(self, device):
import libtorch_agnostic
N = 5
tensors = [torch.rand(32, 16, device=device) for _ in range(N)]
others = [torch.rand(32, 16, device=device) for _ in range(N)]
result = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(result, expected):
self.assertEqual(result_t, expected_t)
def _make_cuda_tensors(prior_mem):
cuda_res = libtorch_agnostic.ops.my__foreach_mul(tensors, others)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
expected = torch._foreach_mul(tensors, others)
for result_t, expected_t in zip(cuda_res, expected):
self.assertEqual(result_t, expected_t)
if tensors[0].is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
def test_make_tensor_clones_and_call_foreach(self, device):
import libtorch_agnostic
t1 = torch.rand(2, 5, device=device)
t2 = torch.rand(3, 4, device=device)
result = libtorch_agnostic.ops.make_tensor_clones_and_call_foreach(t1, t2)
self.assertEqual(result[0], t1 * t1)
self.assertEqual(result[1], t2 * t2)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":

View File

@ -76,7 +76,7 @@ class ReplicateTest(MultiProcessTestCase):
store=dist.FileStore(self.file_name, self.world_size),
)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer(self):
"""
This tests that replicate works on a transformer model with fully_shard and replicate layers
@ -126,7 +126,7 @@ class ReplicateTest(MultiProcessTestCase):
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Shard(dim=0),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_transformer_managed_modules(self):
"""
This tests that replicate managed modules works properly. In this test we use a Transformer Module with 3 layers,
@ -178,7 +178,7 @@ class ReplicateTest(MultiProcessTestCase):
replicate_model = replicate(replicate_model)
self.assertEqual(len(_get_managed_modules((replicate_model,))), 21)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_replicate_tp_device_mesh(self):
"""
This tests that a user can pass in a device mesh to replicate a module
@ -206,7 +206,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
"""
Tests that replicate_model has the same behavior as original model when training
@ -253,7 +253,7 @@ class ReplicateTest(MultiProcessTestCase):
self.assertEqual(replicate_loss, loss)
check_sharded_parity(self, model, replicate_model)
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_train_parity_2d_mlp(self):
"""
Verifies when a device mesh is passed in, the model has the same behavior as the original model when training

View File

@ -80,7 +80,7 @@ class TestSACILP(TestCase):
# postprocessing due to the fact that for ModTracker, the post backward hook
# is not being called for modules whose inputs don't require gradients
# TODO: fix this in ModTracker and ensure it does not lead to any perf regression
if _ModState.POST_BW not in mod_stats.snapshots:
if _ModState.POST_BW not in mod_stats.snapshots.keys():
mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append(
copy.deepcopy(last_snapshot)
)

View File

@ -16,7 +16,7 @@ from torch.distributed.argparse_util import check_env, env
class ArgParseUtilTest(unittest.TestCase):
def setUp(self):
# remove any lingering environment variables
for e in os.environ.keys(): # noqa: SIM118
for e in os.environ.keys():
if e.startswith("PET_"):
del os.environ[e]

View File

@ -207,7 +207,7 @@ class TestDefaultStager(TestCase):
for i, result in enumerate(staged_results):
self.assertIsInstance(result, dict)
# Verify the result contains the expected keys
for key in state_dicts[i]:
for key in state_dicts[i].keys():
self.assertIn(key, result)
stager.close()

View File

@ -299,7 +299,7 @@ class TestDTensorReshardMeshChange(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_dtensor_checkpoint_with_uneven_shards(self) -> None:
"""
Saving a dtensor with uneven shards.
@ -436,7 +436,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_checkpointable_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.
@ -499,7 +498,6 @@ class TestCheckpointableReshard(DTensorTestBase):
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_uneven_reshard_with_dtensor_shards_wrapper_api(self) -> None:
"""
Saves a 1d distributed tensor that has shards with uneven sizes using Checkpointable API.

View File

@ -60,7 +60,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
)
@ -89,7 +89,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
)
@ -116,7 +116,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_loaded[key])
)
@ -156,7 +156,7 @@ class TestSingleRankSaveLoad(TestCase):
self.assertEqual(
sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())
)
for key in state_dict_to_save:
for key in state_dict_to_save.keys():
self.assertTrue(
torch.equal(state_dict_to_save[key], state_dict_to_load[key])
)

View File

@ -769,7 +769,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
model_state_dict3 = copy.deepcopy(model_state_dict3)
self.assertEqual(len(model_state_dict2), 2)
self.assertEqual(len(model_state_dict3), 2)
for key in model_state_dict3:
for key in model_state_dict3.keys():
full_fqn = f"l.{key}"
value1 = model_state_dict1[full_fqn]
value2 = model_state_dict2[full_fqn]
@ -886,7 +886,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
self.assertEqual(cpu_model_value, meta_model_value)
@with_comms
@skip_if_lt_x_gpu(4)
@skip_if_lt_x_gpu(2)
def test_setting_meta_device_model_broadcasting_and_memory(self) -> None:
# This test verifies that we can set model state dict by a meta device model
# With the correlated changes in state_dict, meta device model should be accepted

View File

@ -479,7 +479,6 @@ class TestFSDPMiscMultiProcess(FSDPTest):
for (n, p), (n_prev, p_prev) in zip(
fsdp_overlap.named_parameters(), fsdp_overlap_prev_params
):
self.assertEqual(n, n_prev)
self.assertNotEqual(
p,
p_prev,

View File

@ -587,7 +587,9 @@ class TestFSDPStateDict(FSDPTest):
model, cpu_offload.offload_params, fp16
)
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
ignore_keys = [
k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
]
self._validate_state_dict_contents(
model,
@ -908,7 +910,7 @@ class TestFSDPStateDict(FSDPTest):
with sd_mgr:
fsdp_state_dict = model.state_dict()
ignore_keys = [k for k in fsdp_state_dict if NON_ROOT_FSDP_PREFIX in k]
ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
self._validate_state_dict_contents(
model,
fsdp_state_dict,
@ -957,7 +959,9 @@ class TestFSDPStateDict(FSDPTest):
# Full name of linear_skip param tensors in SkipModel, as would be
# stored in checkpoint.
linear_skip_tensor_names = [
k for k in dict(module.named_parameters()) if LINEAR_SKIP in k
k
for k in dict(module.named_parameters()).keys()
if LINEAR_SKIP in k
]
# skip SkipModule
linear_skip = getattr(module, LINEAR_SKIP)

View File

@ -137,7 +137,7 @@ class ElasticLaunchTest(unittest.TestCase):
self.test_dir = tempfile.mkdtemp()
# remove any lingering environment variables.
for env in os.environ.keys(): # noqa: SIM118
for env in os.environ.keys():
if env.startswith("PET_"):
del os.environ[env]

View File

@ -69,7 +69,7 @@ class ElasticLaunchTest(TestCase):
self.test_dir = tempfile.mkdtemp()
# remove any lingering environment variables
for env in os.environ.keys(): # noqa: SIM118
for env in os.environ.keys():
if env.startswith("PET_"):
del os.environ[env]

View File

@ -39,7 +39,6 @@ from torch.nn.modules.loss import MSELoss
from torch.testing._internal.common_distributed import (
MultiProcContinuousTest,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
check_leaked_tensors,
@ -47,7 +46,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_MULTIACCELERATOR,
)
@ -58,6 +56,7 @@ batch_size = 64
torch.manual_seed(0)
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = dist.get_default_backend_for_device(device_type)
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
@dataclass
@ -232,7 +231,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [_ScheduleForwardOnly])
@skip_if_lt_x_gpu(4)
def test_forward_only(self, ScheduleClass):
mod, mod_ref, x, _, _ = setup_models_and_data(self.config)
x_clone = x.clone()
@ -276,7 +274,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_eval_inference_mode(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -354,7 +351,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_return_output(self, ScheduleClass):
num_microbatches = 4
if ScheduleClass in [
@ -410,7 +406,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_multi_iter(self, ScheduleClass):
mod, _, x, target, loss_fn = setup_models_and_data(self.config)
chunks = 4
@ -434,7 +429,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_kwargs_with_tracer(self, ScheduleClass):
mod = ModelWithKwargs(d_hid, splits=self.world_size)
mod.to(self.device)
@ -487,7 +481,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@skip_if_lt_x_gpu(4)
def test_grad_with_tracer(self, ScheduleClass):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -530,7 +523,6 @@ class ScheduleTest(MultiProcContinuousTest):
)
@parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B])
@parametrize("shape_inference", [True, False])
@skip_if_lt_x_gpu(4)
def test_grad_with_manual(self, ScheduleClass, shape_inference):
mod, ref_mod, x, target, loss_fn = setup_models_and_data(self.config)
@ -594,7 +586,6 @@ class ScheduleTest(MultiProcContinuousTest):
ScheduleInterleavedZeroBubble,
],
)
@skip_if_lt_x_gpu(4)
def test_grad_with_manual_interleaved(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -659,7 +650,6 @@ class ScheduleTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble])
@skip_if_lt_x_gpu(4)
def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -746,7 +736,6 @@ class ScheduleTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleZBVZeroBubble, ScheduleDualPipeV],
)
@skip_if_lt_x_gpu(4)
def test_v_shape_schedules(self, schedule_class):
n_stages = 8
rank_stages = {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4]}
@ -791,7 +780,6 @@ class ScheduleTest(MultiProcContinuousTest):
@skip_but_pass_in_sandcastle_if(
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@skip_if_lt_x_gpu(4)
def test_custom_function_callback(self):
"""Test the custom function callback functionality with _PipelineScheduleRuntime."""
n_stages = 8
@ -991,7 +979,6 @@ class ScheduleTest(MultiProcContinuousTest):
"ScheduleClass",
[ScheduleInterleavedZeroBubble, ScheduleInterleaved1F1B],
)
@skip_if_lt_x_gpu(4)
def test_zero_bubble_with_model_kwargs(self, ScheduleClass):
stages_per_rank = 2
n_stages = stages_per_rank * self.world_size
@ -1085,7 +1072,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
"schedule_class",
[ScheduleVShaped, ScheduleUnbalanced],
)
@skip_if_lt_x_gpu(4)
def test_non_symmetric_stage_ids(self, schedule_class):
n_stages = schedule_class.n_stages
rank_stages = schedule_class.rank_stages
@ -1135,7 +1121,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithReorderedB])
@skip_if_lt_x_gpu(4)
def test_pipeline_schedule_runtime_custom_sched(self, ScheduleClass):
n_stages = 2
stages_per_rank = 1
@ -1196,7 +1181,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
not TEST_MULTIACCELERATOR, f"{backend} test requires 2+ GPUs"
)
@parametrize("ScheduleClass", [ScheduleWithW])
@skip_if_lt_x_gpu(4)
def test_schedule_with_native_zero_bubble(self, ScheduleClass):
n_stages = ScheduleClass.n_stages
num_microbatches = ScheduleClass.num_microbatches

View File

@ -24,7 +24,6 @@ from torch.testing._internal.common_utils import (
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
TEST_MULTIACCELERATOR,
)
from torch.utils._pytree import tree_map_only
@ -35,6 +34,7 @@ chunks = 8
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
backend = dist.get_default_backend_for_device(device_type)
TEST_MULTIACCELERATOR = torch.accelerator.device_count() >= 2
torch.manual_seed(0)

View File

@ -535,19 +535,6 @@ class DTensorExportTest(TestCase):
self.assertEqual(fn(z), gm(z)[0])
def test_dtensor_data_dependent_index(self):
device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
class Foo(torch.nn.Module):
def forward(self, x, y):
return x[y]
x = torch.randn(10)
y = torch.randint(1, (10,)).bool()
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
instantiate_parametrized_tests(DTensorExportTest)

View File

@ -26,7 +26,6 @@ from torch.distributed.tensor.parallel import (
RowwiseParallel,
SequenceParallel,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
create_local_tensor_test_class,
@ -765,7 +764,6 @@ class DistMathOpsTest(DTensorTestBase):
self.assertEqual(grad1_norm.device_mesh, mesh_y)
@with_comms
@skip_if_lt_x_gpu(4)
def test_foreach_add_different_mesh(self):
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(

View File

@ -577,7 +577,7 @@ class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
self.assertEqual(
comm_mode.get_comm_counts(),
{
torch.ops.c10d_functional.all_gather_into_tensor: self.world_size,
torch.ops.c10d_functional.all_gather_into_tensor: 4,
},
)
expected_cost = [

View File

@ -1,18 +1,11 @@
# Owner(s): ["oncall: distributed"]
import itertools
from contextlib import nullcontext
from typing import Any
import torch
import torch.distributed as dist
from torch.distributed._local_tensor import (
local_tensor_mode,
LocalTensor,
LocalTensorMode,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed.tensor import distribute_tensor, DTensor
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._utils import (
_compute_local_shape_and_global_offset,
@ -21,7 +14,6 @@ from torch.distributed.tensor._utils import (
compute_global_tensor_shape,
compute_local_shape_and_global_offset,
compute_local_tensor_info,
ExplicitRedistributionContext,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.placement_types import (
@ -859,93 +851,5 @@ class Test2DStridedLocalShard(DTensorTestBase):
self.assertEqual(global_tensor, dtensor_2d.full_tensor())
class LocalTensorTestBase(TestCase):
def assertEqual(self, lhs, rhs, **kwargs):
mode = local_tensor_mode()
with nullcontext() if mode is None else mode.disable():
if isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor):
assert isinstance(lhs, LocalTensor) and isinstance(rhs, LocalTensor)
super().assertEqual(lhs._ranks, rhs._ranks)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r],
rhs._local_tensors[r],
lambda m: f"rank {r}: {m}",
)
elif isinstance(lhs, LocalTensor) or isinstance(rhs, LocalTensor):
lhs, rhs = (lhs, rhs) if isinstance(lhs, LocalTensor) else (rhs, lhs)
for r in lhs._ranks:
super().assertEqual(
lhs._local_tensors[r], rhs, lambda m: f"rank {r}: {m}"
)
else:
return super().assertEqual(lhs, rhs, **kwargs)
@property
def world_size(self):
raise NotImplementedError("override world-size in your subclass")
def build_device_mesh(self) -> DeviceMesh:
return init_device_mesh("cpu", (self.world_size,))
def setUp(self):
super().setUp()
torch.distributed.init_process_group(
# TODO: test other ranks too
"fake",
rank=0,
world_size=self.world_size,
)
def tearDown(self):
super().tearDown()
try:
dist.destroy_process_group()
except AssertionError:
pass
class TestExplicitRedistribute(LocalTensorTestBase):
@property
def world_size(self):
return 4
def test_explicit_matmul(self):
with LocalTensorMode(self.world_size):
device_mesh = self.build_device_mesh()
dim = 128
x = torch.randn(8, dim, requires_grad=True)
A = torch.randn(dim, dim, requires_grad=True)
# Prepare DTensors
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Shard(0)])
# implicit redistribute works as usual by default
with CommDebugMode() as comm_mode:
torch.matmul(dx, dA)
self.assertEqual(comm_mode.get_total_counts(), 1)
# explicit redistribute works too
with ExplicitRedistributionContext():
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
torch.matmul(dx, dA)
# explicit redistribute allows manual redistribute
with ExplicitRedistributionContext():
dA_repl = dA.redistribute(device_mesh, [Replicate()])
torch.matmul(dx, dA_repl)
dx = distribute_tensor(x, device_mesh, [Shard(0)])
dA = distribute_tensor(A, device_mesh, [Replicate()])
with ExplicitRedistributionContext():
dY = torch.matmul(dx, dA_repl)
loss = dY.sum()
# we now see the error during backwards
with self.assertRaisesRegex(RuntimeError, "Implicit redistribution"):
loss.backward()
if __name__ == "__main__":
run_tests()

View File

@ -54,7 +54,6 @@ def apply_reordering_and_get_graph(graph, out_li) -> None:
"max_compute_pre_fetch",
"custom_runtime_estimation",
"insert_overlap_deps",
"collective_estimator",
)
for key in config_keys:
if (val := getattr(dist_opts, key)) is not None:
@ -944,50 +943,6 @@ class TestComputeCommReorderingBucketing(TestComputeCommReorderingMultiProc):
correct = func(inputs_a, inputs_b, ranks=ranks)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_collective_benchmarking_with_real_pg(self):
"""Test collective benchmarking with real process group (falls back on fake)."""
def func(a):
# Test all three collective types with 8x8 (power of 2 size = 256 elements = 1024 bytes for fp32)
ar = _functional_collectives.all_reduce(a, "sum", "0")
ag = _functional_collectives.all_gather_tensor(
a, 0, list(range(self.world_size))
)
rs = _functional_collectives.reduce_scatter_tensor(a, "sum", 0, "0")
b = torch.matmul(a, a)
c = torch.matmul(ar, b)
return c.sum() + ag.sum() + rs.sum()
patches = {
**get_patches(),
"aten_distributed_optimizations.collective_estimator": "benchmark",
"aten_distributed_optimizations.custom_runtime_estimation": None, # Remove custom estimation so benchmarking happens
}
with _dynamo_dist_per_rank_init(
self.rank,
self.world_size,
self.backend(device_type),
fake_pg=not at_least_x_gpu(2),
):
inputs = torch.ones(8, 8, dtype=torch.float, device=device_type) + self.rank
with torch._inductor.config.patch(patches):
compiled = torch.compile(func)
out, aten_graph_str = run_and_get_aten_graph(compiled, inputs)
# Verify all three collective types are present
FileCheck().check("all_reduce").check("all_gather").check(
"reduce_scatter"
).run(aten_graph_str)
# Test passes if compilation succeeded with benchmarking enabled
# Cache verification is tricky due to multiprocess test setup
correct = func(inputs)
self.assertTrue(same(out, correct))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@torch._inductor.config.patch(get_bucket_patches())
def test_multidtype_bucketing(self):

View File

@ -485,7 +485,7 @@ elif TEST_XPU:
def exit_if_lt_x_accelerators(x):
if torch.accelerator.is_available():
if torch.accelerator.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
sys.exit(TEST_SKIPS[f"multi-accelerator-{x}"].exit_code)
def with_comms(func=None):

View File

@ -1,6 +1,4 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
# flake8: noqa: E731
import contextlib
import copy
import functools
@ -17,11 +15,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
CompileCounterWithBackend,
normalize_gm,
)
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfHpu
@ -1655,43 +1649,6 @@ Non-primal fwd outputs from model w/o backward hook: {mod_no_hook_fwd_outputs_no
self.assertEqual(opt_fn(x), fn(x))
def test_return_same_element_twice(self):
def gn(x):
y = torch.sin(x)
return y, y
def fn(x):
return torch.utils.checkpoint.checkpoint(gn, x, use_reentrant=True)
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = True); wrap_body_0 = l_x_ = None
getitem: "f32[4, 4]" = tag_activation_checkpoint[0]
getitem_1: "f32[4, 4]" = tag_activation_checkpoint[1]; tag_activation_checkpoint = None
return (getitem, getitem_1)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
y: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (y, y)
""",
)
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_mutation(self):
counter = 0
@ -1715,114 +1672,6 @@ class GraphModule(torch.nn.Module):
# The mutation is not reapplied in the backward because the flag was on.
self.assertEqual(counter, 1)
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation(self):
def gn(x, z):
out = x.sin()
z.append(out)
return torch.cos(torch.sin(torch.matmul(x, x) @ x)), out
def fn(x):
z = []
out1, out2 = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
return out1, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
@torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True)
def test_nonlocal_list_mutation_hidden(self):
def gn(x, z):
o = torch.matmul(x, x) @ x
out = x.sin()
z.append(out)
return torch.cos(torch.sin(o)), torch.sin(x)
def fn(x):
z = []
outs = torch.utils.checkpoint.checkpoint(
gn,
x,
z,
use_reentrant=False,
)
out1 = outs[0]
# Check that the extra output pytree handling is done properly
out2 = outs[-1]
return out1 + out2, z[0]
x = torch.randn(4, 4, requires_grad=True)
ref = fn(x)
backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(ref[1], res[1])
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4, 4]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None
out1: "f32[4, 4]" = tag_activation_checkpoint[0]
out2: "f32[4, 4]" = tag_activation_checkpoint[1]
getitem_4: "f32[4, 4]" = tag_activation_checkpoint[4]; tag_activation_checkpoint = None
add: "f32[4, 4]" = out1 + out2; out1 = out2 = None
return (add, getitem_4)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[4, 4]"):
matmul: "f32[4, 4]" = torch.matmul(l_x_, l_x_)
o: "f32[4, 4]" = matmul @ l_x_
out: "f32[4, 4]" = l_x_.sin()
sin_1: "f32[4, 4]" = torch.sin(o)
child: "f32[4, 4]" = torch.cos(sin_1)
child_1: "f32[4, 4]" = torch.sin(l_x_); l_x_ = None
return (child, child_1, matmul, o, out, sin_1)
""",
)
self.assertExpectedInline(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[4, 4]"):
mm: "f32[4, 4]" = torch.ops.aten.mm.default(primals_1, primals_1)
mm_1: "f32[4, 4]" = torch.ops.aten.mm.default(mm, primals_1); mm = None
sin: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
sin_1: "f32[4, 4]" = torch.ops.aten.sin.default(mm_1); mm_1 = None
cos: "f32[4, 4]" = torch.ops.aten.cos.default(sin_1); sin_1 = None
sin_2: "f32[4, 4]" = torch.ops.aten.sin.default(primals_1)
add: "f32[4, 4]" = torch.ops.aten.add.Tensor(cos, sin_2); cos = sin_2 = None
return (add, sin, primals_1)
""",
)
devices = ["cuda", "hpu"]
instantiate_device_type_tests(

View File

@ -408,9 +408,6 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
self.assertEqual(ref0, res0)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_reconstruct(self):
def fn(x):
e = torch.cuda.Event()
@ -428,9 +425,6 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
self.assertEqual(cnts.op_count, 3)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_across_graph_break(self):
def fn(x):
e = torch.cuda.Event()
@ -452,12 +446,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = opt_fn(x)
self.assertEqual(ref[0], res[0])
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 10)
self.assertEqual(cnts.op_count, 9)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_created_outside_of_graph(self):
user_stream = torch.cuda.Stream()
event = torch.cuda.Event()
@ -487,12 +478,9 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
res = run_iters(func, compile=True)
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 4)
self.assertEqual(cnts.op_count, 3)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip(
"Will not support external events for now: https://github.com/pytorch/pytorch/issues/167257"
)
def test_cuda_event_method_create_stream_outside_of_compile(self):
def fn(x, cur_stream, new_stream):
x = torch.mul(x, 1)
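The re-enabled tests above cover `torch.cuda.Event` objects created inside and outside compiled regions. A minimal usage sketch of the underlying eager API, guarded on CUDA availability; nothing here is specific to this PR.

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()                  # enqueue the event on the current stream
    y = x @ x
    end.record()

    torch.cuda.synchronize()        # wait for both recorded events to complete
    print(start.elapsed_time(end))  # elapsed milliseconds between the two events
```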

View File

@ -2109,89 +2109,6 @@ Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: '
with self.assertRaises(Unsupported):
outer_f2(inp)
def test_disable_recursive_flags(self):
class SimpleLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = torch.nn.Linear(4, 4)
def forward(self, inp):
return self.layer0(torch.sigmoid(inp))
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer0 = SimpleLinear()
self.layer1 = torch.nn.Linear(4, 4)
def forward(self, inp):
z = self.layer0(torch.sin(inp))
return self.layer1(z)
for recursive_flag in [True, False]:
model = SimpleModel()
other_model = SimpleModel()
model.forward = torch._dynamo.disable(
model.forward,
recursive=recursive_flag,
)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(model.forward),
recursive_flag,
)
other_model = torch._dynamo.disable(other_model, recursive=recursive_flag)
self.assertEqual(
torch._dynamo.is_dynamo_disable_recursive(
other_model.forward
if isinstance(other_model, torch.nn.Module)
else other_model
),
recursive_flag,
)
# check the model is compilable
torch.compile(model)
torch.compile(other_model)
def test_dynamo_disable_annotations(self):
class SimpleModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("buffer", torch.rand(2, 2))
@torch._dynamo.disable()
def f1(self, x) -> torch.Tensor:
return x + self.buffer + 1
@torch._dynamo.disable()
def f2(self, x) -> torch.Tensor:
return x + self.buffer + 2
def forward(self, x) -> torch.Tensor:
return self.f1(x) + self.f2(x)
model = SimpleModel()
inp = torch.rand(2, 2)
with torch.fx.traceback.preserve_node_meta():
exported_model = torch.export.export(model, (inp,))
graph = exported_model.graph_module.graph
found_f1 = False
found_f2 = False
for node in graph.nodes:
if "custom" in node.meta:
if "_torchdynamo_disable_method" in node.meta["custom"]:
if node.meta["custom"]["_torchdynamo_disable_method"] == "f1":
found_f1 = True
elif node.meta["custom"]["_torchdynamo_disable_method"] == "f2":
found_f2 = True
self.assertTrue(found_f1)
self.assertTrue(found_f2)
model.forward = torch._dynamo.disable(model.forward, recursive=False)
with self.assertRaises(RuntimeError):
exported_model = torch.export.export(model, (inp,))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
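A minimal sketch of the `recursive=` flag these new tests check. The `is_dynamo_disable_recursive` helper is introduced alongside this test and is not assumed here; the rest is existing API.

```python
import torch
import torch._dynamo

class SimpleLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer0 = torch.nn.Linear(4, 4)

    def forward(self, inp):
        return self.layer0(torch.sigmoid(inp))

model = SimpleLinear()
# recursive=False skips Dynamo for this frame only; functions it calls may still be traced.
model.forward = torch._dynamo.disable(model.forward, recursive=False)
out = torch.compile(model)(torch.randn(2, 4))   # falls back to eager for the disabled frame
```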

View File

@ -422,41 +422,34 @@ from user code:
import optree
@torch.compile(backend="eager")
def fn1(x):
tree = {"a": x, "b": (x - 1, 2 * x)}
sin, cos = optree.tree_transpose_map(
lambda t: (torch.sin(t), torch.cos(t)),
tree,
def fn(x):
d = {"a": 1}
optree.tree_flatten_with_path(d)
return torch.sin(x)
def post_munge(s):
s = re.sub(
r"optree\.\S*\.flatten_with_path",
"optree.<path>.flatten_with_path",
s,
)
return sin, cos
fn1(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 0)
@torch.compile(backend="eager")
def fn2(x):
spec = optree.treespec_deque([])
return spec, x
fn2(torch.randn(4))
self.assertGreaterEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
def post_munge(string):
return re.sub(
r"(optree\.|qualname: )\S*(\.make_from_collection)",
r"\1<path>\2",
string,
r"qualname: \S*flatten_with_path",
"qualname: <path>.flatten_with_path",
s,
)
fn(torch.randn(4))
self.assertEqual(len(counters["graph_break"]), 1)
first_graph_break = next(iter(counters["graph_break"].keys()))
self.assertExpectedInline(
post_munge(first_graph_break),
"""\
Attempted to call function marked as skipped
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.make_from_collection.
Explanation: Dynamo cannot trace optree C/C++ function optree.<path>.flatten_with_path.
Hint: Consider using torch.utils._pytree - https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py
Developer debug context: module: optree._C, qualname: <path>.make_from_collection, skip reason: <missing reason>
Developer debug context: module: optree._C, qualname: <path>.flatten_with_path, skip reason: <missing reason>
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""",
)
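The hint in the expected message points at `torch.utils._pytree` as the traceable alternative to optree's C bindings. A minimal sketch of that suggested replacement (standard pytree API, not added by this PR):

```python
import torch
import torch.utils._pytree as pytree

@torch.compile(backend="eager")
def fn(x):
    d = {"a": x, "b": (x - 1, 2 * x)}
    # Pure-Python pytree utilities, which Dynamo can trace (unlike optree._C).
    leaves, spec = pytree.tree_flatten(d)
    return pytree.tree_unflatten([t.sin() for t in leaves], spec)

fn(torch.randn(4))
```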
@ -1050,7 +1043,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
msg = re.sub(r"line (\d+)", "line N", msg)
msg = re.sub(
r"""(?s)Traceback \(most recent call last\):.*
File "exc.py", line N, in unimplemented
File "exc.py", line N, in unimplemented_v2
raise Unsupported\(msg\)""",
"<Internal traceback>\n",
msg,

View File

@ -3354,7 +3354,7 @@ class GraphModule(torch.nn.Module):
x = torch.randn(2, 4)
y = torch.ones(4)
msg = "hints_wrapper: improper args/kwargs"
msg = "hints_wrapper - key hints not provided"
with self.assertRaisesRegex(RuntimeError, msg):
torch.compile(fn_with_hints, backend=cnt)(x, y)
@ -4516,9 +4516,12 @@ class GraphModule(torch.nn.Module):
model, params, inputs, targets
)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertIn(
"torch.func.functional_call capture is disabled",
next(iter(counters["graph_break"].keys())),
self.assertEqual(
{
"torch.func.functional_call capture is disabled, it can be "
"turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1,
},
dict(counters["graph_break"]),
)
self.assertEqual(actual, expected)
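The updated assertion pins the exact graph-break message for `torch.func.functional_call` when NN-module inlining is off. A minimal sketch of the call itself, with the config toggle named in that message shown as a comment (flag name taken from the message above):

```python
import torch
from torch.func import functional_call

model = torch.nn.Linear(3, 3)
params_and_buffers = {**dict(model.named_parameters()), **dict(model.named_buffers())}
x = torch.randn(2, 3)

# Stateless call with an explicit parameter/buffer dict.
out = functional_call(model, params_and_buffers, (x,))

# Per the message above, the graph break goes away when NN modules are inlined:
# torch._dynamo.config.inline_inbuilt_nn_modules = True
```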

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: dynamo"]
import unittest
from collections.abc import Sequence
from typing import Any, Callable, Union
from collections.abc import Callable, Sequence
from typing import Any, Union
import torch
import torch._dynamo

View File

@ -988,7 +988,6 @@ exclusions = {
"hierarchical_compile",
"compute_dependencies",
"annotation",
"node_runtime_estimation",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@ -1,5 +1,5 @@
# Owner(s): ["module: dynamo"]
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional, TYPE_CHECKING
import torch
import torch._dynamo
@ -7,6 +7,10 @@ from torch._dynamo.test_case import run_tests, TestCase
from torch._dynamo.testing import CompileCounter, same
if TYPE_CHECKING:
from collections.abc import Callable
"""
This is an example of a pure-python version of autograd implemented by
@zdevito. It represents a rather challenging test case for TorchDynamo

View File

@ -67,7 +67,7 @@ class IgnoreLogsTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(counters["graph_break"]), 0)
else:
self.assertIn("moo", printed_output)
self.assertGreater(len(counters["graph_break"]), 0)
self.assertEqual(len(counters["graph_break"]), 1)
class ReorderLogsTests(torch._dynamo.test_case.TestCase):

View File

@ -3,7 +3,6 @@ import functools
import re
import unittest
import weakref
from unittest.mock import patch
import torch
import torch._dynamo.test_case
@ -446,37 +445,6 @@ class GraphModule(torch.nn.Module):
""",
)
@requires_cuda
def test_event_tracing(self):
def fn(x) -> None:
e = torch.Event()
e.record()
x.add_(1)
return x
inp = (torch.ones(2, 2, device="cuda"),)
(
_,
_,
fw_graphs,
_,
) = extract_graph(fn, *inp)
self.assertExpectedInline(
print_graph(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]"):
#
record_event = torch.ops.streams.record_event.default(0, 1); record_event = None
#
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1)
copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = None
return (copy_,)
""",
)
@requires_cuda
def test_run_opcheck_fork_join(self):
from torch._dynamo.variables.streams import fork_stream, join_stream
@ -523,20 +491,6 @@ class <lambda>(torch.nn.Module):
torch.accelerator.set_stream(original_stream)
reset_user_object_tracking()
@requires_cuda
def test_inductor_lowering(self):
with patch("torch._inductor.config.implicit_fallbacks", False):
@torch.compile()
def fn(x):
e = torch.Event()
x += x + 1
e.record()
return x
inp = (torch.ones(2, 2, device="cuda"),)
fn(*inp)
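The removed `test_inductor_lowering` compiled a function that records a device-agnostic `torch.Event`. A minimal sketch of that pattern outside the test harness, CUDA-guarded; the lowering behavior is what the deleted test asserted, not something this sketch verifies.

```python
import torch

if torch.cuda.is_available():
    @torch.compile()
    def fn(x):
        e = torch.Event()   # device-agnostic event, as in the removed test
        x += x + 1
        e.record()          # recorded on the current accelerator stream
        return x

    out = fn(torch.ones(2, 2, device="cuda"))
```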
def test_is_marked_side_effectful(self):
self.assertIn(
torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions

View File

@ -742,14 +742,11 @@ class TestExport(TestCase):
self.assertExpectedInline(
str(custom_metadata),
"""\
('placeholder', 'x', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('placeholder', 'y', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})
('call_function', 'cat', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'item', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'ge_1', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', '_assert_scalar_default', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('call_function', 'mul', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace', 'moo': 0})
('output', 'output', {'_torchdynamo_disable': True, '_torchdynamo_disable_recursive': True, '_torchdynamo_disable_method': 'dispatch_trace'})""",
('call_function', 'cat', {'moo': 0})
('call_function', 'item', {'moo': 0})
('call_function', 'ge_1', {'moo': 0})
('call_function', '_assert_scalar_default', {'moo': 0})
('call_function', 'mul', {'moo': 0})""",
)
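The expected string above now keeps only the user-set `'moo'` key in `node.meta['custom']`, with the `_torchdynamo_disable*` entries gone. A minimal sketch of exporting a module and inspecting per-node metadata; the `'custom'` key is only populated when tracing runs under the annotation context the test uses, so plain export prints empty dicts here.

```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.cat([x, y]).mul(2)

ep = export(M(), (torch.randn(2), torch.randn(2)))
for node in ep.graph_module.graph.nodes:
    # 'custom' is absent unless the graph was produced under an annotation context.
    print(node.op, node.name, node.meta.get("custom", {}))
```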
@requires_gpu
@ -1224,14 +1221,8 @@ graph():
%p_block_linear2_bias : [num_users=1] = placeholder[target=p_block_linear2_bias]
%x : [num_users=1] = placeholder[target=x]
%wrap_body0 : [num_users=1] = get_attr[target=wrap_body0]
%tag_activation_checkpoint : [num_users=7] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
%tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body0, %x, %p_block_linear1_weight, %p_block_linear1_bias, %p_block_linear2_weight, %p_block_linear2_bias), kwargs = {})
%getitem : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
%getitem_1 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 1), kwargs = {})
%getitem_2 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 2), kwargs = {})
%getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 3), kwargs = {})
%getitem_4 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 4), kwargs = {})
%getitem_5 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 5), kwargs = {})
%getitem_6 : [num_users=0] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 6), kwargs = {})
return (getitem,)""",
)
@ -1240,14 +1231,14 @@ graph():
"""\
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
%arg2_1 : [num_users=2] = placeholder[target=arg2_1]
%arg3_1 : [num_users=2] = placeholder[target=arg3_1]
%arg4_1 : [num_users=2] = placeholder[target=arg4_1]
%linear : [num_users=2] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
%relu : [num_users=2] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
%arg1_1 : [num_users=1] = placeholder[target=arg1_1]
%arg2_1 : [num_users=1] = placeholder[target=arg2_1]
%arg3_1 : [num_users=1] = placeholder[target=arg3_1]
%arg4_1 : [num_users=1] = placeholder[target=arg4_1]
%linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg0_1, %arg1_1, %arg2_1), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%linear,), kwargs = {})
%linear_1 : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%relu, %arg3_1, %arg4_1), kwargs = {})
return (linear_1, arg1_1, arg2_1, linear, relu, arg3_1, arg4_1)""",
return (linear_1,)""",
)
stack = contextlib.ExitStack()

View File

@ -4,7 +4,6 @@ from unittest.mock import patch
import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
@ -40,56 +39,6 @@ class TestHopPrint(TestCase):
self.assertEqual(printed_output, "moo 1 2")
fx_f = make_fx(f)(x)
new_inp = torch.randn(3, 3)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
fx_f(new_inp)
ori_printed_output = mock_stdout.getvalue().strip()
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
f(new_inp)
fx_printed_output = mock_stdout.getvalue().strip()
self.assertEqual(ori_printed_output, fx_printed_output)
def test_print_with_proxy_graph(self):
class M(torch.nn.Module):
def forward(self, x):
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
torch._higher_order_ops.print("moo {x}", x=x)
res = x + x
torch._higher_order_ops.print("moo {x} {y}", x=1, y=2)
torch._higher_order_ops.print("yeehop {x}", x=x.shape[0])
return (res,)
inputs = (torch.randn(3),)
# Without functionalization, print should just appear in the graph directly
gm = make_fx(M(), tracing_mode="symbolic")(*inputs)
self.assertExpectedInline(
str(gm.code).strip(),
"""\
def forward(self, arg0_1):
print_1 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_1 = None
print_2 = torch.ops.higher_order.print('moo {x}', x = arg0_1); print_2 = None
add = torch.ops.aten.add.Tensor(arg0_1, arg0_1)
print_3 = torch.ops.higher_order.print('moo {x} {y}', x = 1, y = 2); print_3 = None
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0); arg0_1 = None
print_4 = torch.ops.higher_order.print('yeehop {x}', x = sym_size_int); sym_size_int = print_4 = None
return (add,)""",
)
new_inp = torch.randn(4)
with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout:
gm(
new_inp,
)
printed_output = mock_stdout.getvalue().strip()
self.assertEqual(printed_output, f"moo 1 2\nmoo {new_inp}\nmoo 1 2\nyeehop 4")
if __name__ == "__main__":
run_tests()
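The deleted tests trace printing HOPs through `make_fx`. A minimal sketch of `make_fx` itself on a plain function, with no higher-order print involved:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return (x + x).sin()

gm = make_fx(f)(torch.randn(3, 3))   # trace f into an FX GraphModule
print(gm.code)                       # aten-level ops: add, sin
gm(torch.randn(3, 3))                # the traced module is directly callable
```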

View File

@ -475,14 +475,17 @@ class TestFxGraphCache(TestCase):
if device == GPU_TYPE and not HAS_GPU:
raise unittest.SkipTest(f"requires {GPU_TYPE}")
if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
if (
device == "cuda"
and torch.version.hip is None
and dtype == torch.bfloat16
and not SM80OrLater
):
raise unittest.SkipTest("requires SM80 or later")
if use_static_cuda_launcher and not (device == "cuda" and bundle_triton):
raise unittest.SkipTest(
"Static cuda launcher requires cuda and triton bundling"
)
if use_static_cuda_launcher and TEST_WITH_ROCM:
raise unittest.SkipTest("Static cuda launcher doesn't work with ROCM")
def fn(x, y):
return (x * 2, y @ y)
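The new skip conditions distinguish ROCm from CUDA builds and gate bfloat16 on SM80. A hedged sketch of equivalent runtime checks using public APIs; `SM80OrLater` and `TEST_WITH_ROCM` above are internal test helpers and are not assumed here.

```python
import torch

is_rocm = torch.version.hip is not None
sm80_or_later = (
    torch.cuda.is_available()
    and not is_rocm
    and torch.cuda.get_device_capability() >= (8, 0)
)
# bfloat16 matmuls on NVIDIA GPUs want SM80+; ROCm builds skip that check entirely.
```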

Some files were not shown because too many files have changed in this diff