Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)

Compare commits: replace-py ... codex/fix- (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | f00a1b0349 |  |
|  | bc67bce2e5 |  |

@@ -438,7 +438,9 @@ def build_torchvision(
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
        build_vars += (
            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
        )
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

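The only functional difference between the two `PYTORCH_VERSION` lines in this hunk (and in the three analogous hunks below) is the `maxsplit` argument when splitting the release branch name. For taking element `[0]` the two forms agree; a minimal sketch, with a hypothetical branch name that is not taken from this diff:

```python
# Minimal sketch of the two branch-parsing variants seen in the hunk above.
# The branch value here is a hypothetical example.
branch = "v2.5.0-rc1"

with_maxsplit = branch[1:].split("-", maxsplit=1)[0]   # "2.5.0"
without_maxsplit = branch[1:].split("-")[0]            # "2.5.0"

# Both take element [0], so the extracted version is identical either way;
# maxsplit=1 merely stops splitting after the first '-'.
assert with_maxsplit == without_maxsplit
print(with_maxsplit)
```
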
@@ -493,7 +495,9 @@ def build_torchdata(
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
        build_vars += (
            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
        )
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
@@ -549,7 +553,9 @@ def build_torchtext(
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
        build_vars += (
            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
        )
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
@@ -607,7 +613,9 @@ def build_torchaudio(
        )
        build_vars += f"BUILD_VERSION={version}.dev{build_date}"
    elif build_version is not None:
        build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
        build_vars += (
            f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
        )
    if host.using_docker():
        build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

@@ -176,7 +176,7 @@ case "$tag" in
    VISION=yes
    TRITON=yes
    ;;
  pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
  pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3)
    if [[ $tag =~ "jammy" ]]; then
      ANACONDA_PYTHON_VERSION=3.10
    else
@@ -190,9 +190,7 @@ case "$tag" in
    KATEX=yes
    UCX_COMMIT=${_UCX_COMMIT}
    UCC_COMMIT=${_UCC_COMMIT}
    if [[ $tag =~ "benchmarks" ]]; then
      INDUCTOR_BENCHMARKS=yes
    fi
    INDUCTOR_BENCHMARKS=yes
    ;;
  pytorch-linux-noble-rocm-alpha-py3)
    ANACONDA_PYTHON_VERSION=3.12
@@ -204,6 +202,7 @@ case "$tag" in
    KATEX=yes
    UCX_COMMIT=${_UCX_COMMIT}
    UCC_COMMIT=${_UCC_COMMIT}
    INDUCTOR_BENCHMARKS=yes
    PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950"
    ;;
  pytorch-linux-jammy-xpu-2025.0-py3)

@@ -1 +1 @@
v4.54.0
243e186efbf7fb93328dd6b34927a4e8c8f24395

@@ -66,9 +66,8 @@ function do_cpython_build {
        ln -s pip3 ${prefix}/bin/pip
    fi
    # install setuptools since python 3.12 is required to use distutils
    # packaging is needed to create symlink since wheel no longer provides needed information
    ${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0
    local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))")
    ${prefix}/bin/pip install wheel==0.45.1 setuptools==80.9.0
    local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))")
    ln -sf ${prefix} /opt/python/${abi_tag}
}

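For context on the `abi_tag` one-liners above: the `packaging.tags` variant builds a CPython directory name such as `cp312-cp312`, appending a trailing `t` when the interpreter reports `Py_GIL_DISABLED` (free-threaded builds). A self-contained approximation of what it prints, assuming the `packaging` package is installed; the exact output depends on the interpreter running it:

```python
# Rough standalone equivalent of the packaging.tags-based abi_tag expression
# in the hunk above.
import sysconfig
from packaging.tags import interpreter_name, interpreter_version

suffix = "t" if sysconfig.get_config_var("Py_GIL_DISABLED") else ""
abi_tag = "{0}{1}-{0}{1}{2}".format(interpreter_name(), interpreter_version(), suffix)
print(abi_tag)  # e.g. "cp312-cp312", or "cp313-cp313t" on a free-threaded build
```
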
@@ -26,15 +26,15 @@ function install_torchbench() {

  python install.py --continue_on_fail

  # soxr comes from https://github.com/huggingface/transformers/pull/39429
  pip install transformers==4.54.0 soxr==0.5.0
  # TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
  # is regressing speedup metric. This needs to be investigated further
  pip install transformers==4.38.1

  echo "Print all dependencies after TorchBench is installed"
  python -mpip freeze
  popd

  chown -R jenkins torchbench
  chown -R jenkins /opt/conda
}

# Pango is needed for weasyprint which is needed for doctr
@@ -48,4 +48,4 @@ install_huggingface
install_timm

# Clean up
conda_run pip uninstall -y torch torchvision torchaudio triton torchao
conda_run pip uninstall -y torch torchvision torchaudio triton

@@ -34,27 +34,18 @@ function install_ubuntu() {

    # The xpu-smi packages
    apt-get install -y flex bison xpu-smi

    if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
        # Compute and Media Runtimes
        apt-get install -y \
            intel-opencl-icd intel-level-zero-gpu level-zero \
            intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
            libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
            libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
            mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
        # Development Packages
        apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
    else # rolling driver
        apt-get install -y \
            intel-opencl-icd libze-intel-gpu1 libze1 \
            intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
            libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
            libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
            mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
        apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
    # Compute and Media Runtimes
    apt-get install -y \
        intel-opencl-icd intel-level-zero-gpu level-zero \
        intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
        libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
        libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
        mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
    if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
        apt-get install -y intel-ocloc
    fi

    # Development Packages
    apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
    # Install Intel Support Packages
    apt-get install -y ${XPU_PACKAGES}

@@ -139,11 +130,11 @@ function install_sles() {

}

# Default use GPU driver rolling releases
XPU_DRIVER_VERSION=""
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
    # Use GPU driver LTS releases
    XPU_DRIVER_VERSION="/lts/2350"
# Default use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2350"
if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
    # Use GPU driver rolling releases
    XPU_DRIVER_VERSION=""
fi

# Default use Intel® oneAPI Deep Learning Essentials 2025.0

@@ -63,12 +63,11 @@ lark==0.12.0
#Pinned versions: 0.12.0
#test that import:

librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x"
librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x"
librosa>=0.6.2 ; python_version < "3.11"
librosa==0.10.2 ; python_version == "3.12"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2
#test that import: test_spectral_ops.py
#librosa depends on numba; disable it for s390x while numba is disabled too

#mkl #this breaks linux-bionic-rocm4.5-py3.7
#Description: Intel oneAPI Math Kernel Library
@@ -111,15 +110,14 @@ ninja==1.11.1.3
#Pinned versions: 1.11.1.3
#test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py

numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x"
numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x"
numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
numba==0.49.0 ; python_version < "3.9"
numba==0.55.2 ; python_version == "3.9"
numba==0.55.2 ; python_version == "3.10"
numba==0.60.0 ; python_version == "3.12"
#Description: Just-In-Time Compiler for Numerical Functions
#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
#For numba issue see https://github.com/pytorch/pytorch/issues/51511
#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073

#numpy
#Description: Provides N-dimensional arrays and linear algebra
@@ -309,7 +307,7 @@ pytest-cpp==2.3.0
#Pinned versions: 2.3.0
#test that import:

z3-solver==4.15.1.0 ; platform_machine != "s390x"
z3-solver==4.15.1.0
#Description: The Z3 Theorem Prover Project
#Pinned versions:
#test that import:

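The requirements hunks above toggle PEP 508 environment markers (the `platform_machine != "s390x"` clauses) on the librosa, numba, and z3-solver pins; pip evaluates each marker against the installing interpreter's environment and skips the requirement when it is false. A small sketch using the `packaging` library, assuming it is installed; the marker strings are illustrative:

```python
# Evaluate PEP 508 environment markers the way pip does when deciding
# whether a pinned requirement applies on the current platform.
from packaging.markers import Marker

for marker in (
    'platform_machine != "s390x"',
    'python_version == "3.12" and platform_machine != "s390x"',
):
    print(marker, "->", Marker(marker).evaluate())
```
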
@@ -138,11 +138,28 @@ fi

echo "Calling setup.py bdist at $(date)"

time CMAKE_ARGS=${CMAKE_ARGS[@]} \
    EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
    time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
    BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \
    BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
    USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
    python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
    echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
    echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
    time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
    BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \
    BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
    USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
    CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
    echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
else
    time CMAKE_ARGS=${CMAKE_ARGS[@]} \
        EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
        BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
        USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
        python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
fi
echo "Finished setup.py bdist at $(date)"

# Build libtorch packages
@@ -255,6 +272,10 @@ ls /tmp/$WHEELHOUSE_DIR
mkdir -p "/$WHEELHOUSE_DIR"
mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/

if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true
fi

if [[ -n "$BUILD_PYTHONLESS" ]]; then
    mkdir -p /$LIBTORCH_HOUSE_DIR
    mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR
@@ -431,8 +452,16 @@ if [[ -z "$BUILD_PYTHONLESS" ]]; then
  pushd $PYTORCH_ROOT/test

  # Install the wheel for this Python version
  if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true
  fi

  pip uninstall -y "$TORCH_PACKAGE_NAME"

  if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v
  fi

  pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v

  # Print info on the libraries installed in this wheel

@@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then
  export ATEN_THREADING=NATIVE
fi

# Enable LLVM dependency for TensorExpr testing
export USE_LLVM=/opt/llvm
export LLVM_DIR=/opt/llvm/lib/cmake/llvm

if ! which conda; then
  # In ROCm CIs, we are doing cross compilation on build machines with
@@ -173,7 +176,7 @@ fi

# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
# memory to build and will OOM
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' ' '\n' | sed 's/$/>= 8.0/' | bc | grep -q 1; then
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then
  export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2"
fi

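Both variants of the FlashAttention gate above answer the same question: does `TORCH_CUDA_ARCH_LIST` contain an architecture at or above 8.0? The `tr`/`sed`/`bc`/`grep` form tests each space-separated entry individually, while the single `bc` expression evaluates the whole list as one string. A rough Python rendering of the per-entry check, assuming entries such as "7.5" or "8.0+PTX" (the sample list is hypothetical):

```python
# Approximate Python equivalent of the per-architecture check in the hunk above:
# split TORCH_CUDA_ARCH_LIST on spaces and test each entry against 8.0.
def any_arch_at_least(arch_list: str, threshold: float = 8.0) -> bool:
    for entry in arch_list.split():
        numeric = entry.split("+")[0]  # drop suffixes such as "+PTX"
        try:
            if float(numeric) >= threshold:
                return True
        except ValueError:
            continue  # skip entries that are not plain numbers
    return False

print(any_arch_at_least("7.5 8.0+PTX"))  # True
print(any_arch_at_least("7.5"))          # False
```
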
@@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
  export USE_ASAN=1
  export REL_WITH_DEB_INFO=1
  export UBSAN_FLAGS="-fno-sanitize-recover=all"
  unset USE_LLVM
fi

if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then
@@ -261,13 +265,22 @@ else

      WERROR=1 python setup.py clean

      WERROR=1 python setup.py bdist_wheel
      if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
        python3 tools/packaging/split_wheel.py bdist_wheel
      else
        WERROR=1 python setup.py bdist_wheel
      fi
    else
      python setup.py clean
      if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
        source .ci/pytorch/install_cache_xla.sh
      fi
      python setup.py bdist_wheel
      if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
        echo "USE_SPLIT_BUILD cannot be used with xla or rocm"
        exit 1
      else
        python setup.py bdist_wheel
      fi
    fi
    pip_install_whl "$(echo dist/*.whl)"

@@ -175,9 +175,6 @@ checkout_install_torchbench() {
    python install.py --continue_on_fail
  fi

  # soxr comes from https://github.com/huggingface/transformers/pull/39429
  pip install transformers==4.54.0 soxr==0.5.0

  echo "Print all dependencies after TorchBench is installed"
  python -mpip freeze
  popd

@@ -1051,10 +1051,20 @@ test_libtorch_api() {
    mkdir -p $TEST_REPORTS_DIR

    OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml
    "$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
  else
    # Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy
    OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"

    # On s390x, pytorch is built without llvm.
    # Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
    # test fails with errors like:
    # JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
    # unknown file: Failure
    # C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
    if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
      python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
    fi
  fi

  # quantization is not fully supported on s390x yet
@@ -1682,6 +1692,7 @@ elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then
elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
  install_torchaudio
  install_torchvision
  install_torchao
  id=$((SHARD_NUMBER-1))
  # https://github.com/opencv/opencv-python/issues/885
  pip_install opencv-python==4.8.0.74

@@ -61,10 +61,9 @@ if "%USE_XPU%"=="1" (
  call "C:\Program Files (x86)\Intel\oneAPI\compiler\latest\env\vars.bat"
  call "C:\Program Files (x86)\Intel\oneAPI\ocloc\latest\env\vars.bat"
  if errorlevel 1 exit /b 1
  :: Reduce build time
  SET TORCH_XPU_ARCH_LIST=bmg
  :: Re-setup python env for build
  call pip install -r requirements.txt
  :: Reduce build time. Only have MTL self-hosted runner now
  SET TORCH_XPU_ARCH_LIST=xe-lpg
  SET USE_KINETO=0
)

@echo on

@@ -192,6 +192,9 @@ retry brew install libomp
# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule
export USE_DISTRIBUTED=1

if [[ -n "$CROSS_COMPILE_ARM64" ]]; then
    export CMAKE_OSX_ARCHITECTURES=arm64
fi
export USE_MKLDNN=OFF
export USE_QNNPACK=OFF
export BUILD_TEST=OFF
@@ -199,7 +202,16 @@ export BUILD_TEST=OFF
pushd "$pytorch_rootdir"
echo "Calling setup.py bdist_wheel at $(date)"

python setup.py bdist_wheel -d "$whl_tmp_dir"
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
    echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
    BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir"
    echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
    echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
    BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir"
    echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
else
    python setup.py bdist_wheel -d "$whl_tmp_dir"
fi

echo "Finished setup.py bdist_wheel at $(date)"

@@ -65,8 +65,16 @@ fi

if [[ "$PACKAGE_TYPE" != libtorch ]]; then
  if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then
    pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
    retry pip install -q numpy protobuf typing-extensions
    if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
      pkg_no_python="$(ls -1 /final_pkgs/torch_no_python* | sort |tail -1)"
      pkg_torch="$(ls -1 /final_pkgs/torch-* | sort |tail -1)"
      # todo: after folder is populated use the pypi_pkg channel instead
      pip install "\$pkg_no_python" "\$pkg_torch" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}_pypi_pkg"
      retry pip install -q numpy protobuf typing-extensions
    else
      pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
      retry pip install -q numpy protobuf typing-extensions
    fi
  else
    pip install "\$pkg"
    retry pip install -q numpy protobuf typing-extensions

@@ -134,6 +134,7 @@ export DESIRED_PYTHON="${DESIRED_PYTHON:-}"
export DESIRED_CUDA="$DESIRED_CUDA"
export LIBTORCH_VARIANT="${LIBTORCH_VARIANT:-}"
export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}"
export USE_SPLIT_BUILD="${USE_SPLIT_BUILD:-}"
if [[ "${OSTYPE}" == "msys" ]]; then
  export LIBTORCH_CONFIG="${LIBTORCH_CONFIG:-}"
  if [[ "${LIBTORCH_CONFIG:-}" == 'debug' ]]; then

@@ -23,6 +23,10 @@ if [[ "${DRY_RUN}" = "disabled" ]]; then
  AWS_S3_CP="aws s3 cp"
fi

if [[ "${USE_SPLIT_BUILD:-false}" == "true" ]]; then
  UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_pkg"
fi

# this is special build with all dependencies packaged
if [[ ${BUILD_NAME} == *-full* ]]; then
  UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_full"

@@ -24,6 +24,7 @@ runs:
          -e PYTORCH_FINAL_PACKAGE_DIR \
          -e PYTORCH_ROOT \
          -e SKIP_ALL_TESTS \
          -e USE_SPLIT_BUILD \
          --tty \
          --detach \
          -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \

.github/ci_commit_pins/audio.txt (vendored, 2 lines changed)
@@ -1 +1 @@
bdb88e1d66f272cad72156c90ac8428ca61a601c
6fbc710b617f79b992ef2ebc7f95e818aa390293

.github/ci_commit_pins/vllm.txt (vendored, 2 lines changed)
@@ -1 +1 @@
458e74eb907f96069e6d8a4f3c9f457001fef2ea
6a39ba85fe0f2fff9494b5eccea717c93510c230

.github/ci_commit_pins/xla.txt (vendored, 2 lines changed)
@@ -1 +1 @@
095faec1e7b6cc47220181e74ae9cde2605f9b00
b6a5b82b9948b610fa4c304d0d869c82b8f17db1

.github/scripts/generate_binary_build_matrix.py (vendored, 13 lines changed)
@@ -273,6 +273,7 @@ def generate_wheels_matrix(
    os: str,
    arches: Optional[list[str]] = None,
    python_versions: Optional[list[str]] = None,
    use_split_build: bool = False,
) -> list[dict[str, str]]:
    package_type = "wheel"
    if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
@@ -320,6 +321,15 @@ def generate_wheels_matrix(
            ):
                continue

            if use_split_build and (
                arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"
            ):
                raise RuntimeError(
                    "Split build is only supported on linux with cuda 12* and cpu.\n"
                    f"Currently attempting to build on arch version {arch_version} and os {os}.\n"
                    "Please modify the matrix generation to exclude this combination."
                )

            # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install

            if (
@@ -334,6 +344,7 @@ def generate_wheels_matrix(
                        "gpu_arch_type": gpu_arch_type,
                        "gpu_arch_version": gpu_arch_version,
                        "desired_cuda": desired_cuda,
                        "use_split_build": "True" if use_split_build else "False",
                        "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split(
                            ":"
                        )[0],
@@ -366,6 +377,7 @@ def generate_wheels_matrix(
                            "desired_cuda": translate_desired_cuda(
                                gpu_arch_type, gpu_arch_version
                            ),
                            "use_split_build": "True" if use_split_build else "False",
                            "container_image": WHEEL_CONTAINER_IMAGES[
                                arch_version
                            ].split(":")[0],
@@ -388,6 +400,7 @@ def generate_wheels_matrix(
                        "desired_cuda": translate_desired_cuda(
                            gpu_arch_type, gpu_arch_version
                        ),
                        "use_split_build": "True" if use_split_build else "False",
                        "container_image": WHEEL_CONTAINER_IMAGES[arch_version].split(
                            ":"
                        )[0],

.github/scripts/generate_ci_workflows.py (vendored, 42 lines changed)
@@ -59,7 +59,9 @@ class BinaryBuildWorkflow:
    is_scheduled: str = ""
    branches: str = "nightly"
    # Mainly for macos
    cross_compile_arm64: bool = False
    macos_runner: str = "macos-14-xlarge"
    use_split_build: bool = False
    # Mainly used for libtorch builds
    build_variant: str = ""

@@ -70,6 +72,9 @@ class BinaryBuildWorkflow:
                for item in [self.os, "binary", self.package_type, self.build_variant]
                if item != ""
            )
        if self.use_split_build:
            # added to distinguish concurrency groups
            self.build_environment += "-split"

    def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
        output_file_path = (
@@ -112,6 +117,21 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
            isolated_workflow=True,
        ),
    ),
    # See https://github.com/pytorch/pytorch/issues/138750
    #   BinaryBuildWorkflow(
    #     os=OperatingSystem.LINUX,
    #     package_type="manywheel",
    #     build_configs=generate_binary_build_matrix.generate_wheels_matrix(
    #         OperatingSystem.LINUX,
    #         use_split_build=True,
    #         arches=["11.8", "12.1", "12.4", "cpu"],
    #     ),
    #     ciflow_config=CIFlowConfig(
    #         labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
    #         isolated_workflow=True,
    #     ),
    #     use_split_build=True,
    # ),
    BinaryBuildWorkflow(
        os=OperatingSystem.LINUX,
        package_type="libtorch",
@@ -155,11 +175,27 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
        package_type="manywheel",
        build_configs=generate_binary_build_matrix.generate_wheels_matrix(
            OperatingSystem.LINUX,
            arches=["12.8"],
            python_versions=["3.12"],
            arches=["12.6", "12.8", "12.9"],
            python_versions=["3.9"],
        ),
        branches="main",
    ),
    # See https://github.com/pytorch/pytorch/issues/138750
    # BinaryBuildWorkflow(
    #     os=OperatingSystem.LINUX,
    #     package_type="manywheel",
    #     build_configs=generate_binary_build_matrix.generate_wheels_matrix(
    #         OperatingSystem.LINUX,
    #         arches=["11.8", "12.1", "12.4"],
    #         python_versions=["3.9"],
    #         use_split_build=True,
    #     ),
    #     ciflow_config=CIFlowConfig(
    #         labels={LABEL_CIFLOW_PERIODIC},
    #     ),
    #     branches="main",
    #     use_split_build=True,
    # ),
    BinaryBuildWorkflow(
        os=OperatingSystem.LINUX,
        package_type="libtorch",
@@ -302,6 +338,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
            generate_binary_build_matrix.RELEASE,
            libtorch_variants=["shared-with-deps"],
        ),
        cross_compile_arm64=False,
        macos_runner="macos-14-xlarge",
        ciflow_config=CIFlowConfig(
            labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH},
@@ -314,6 +351,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
        build_configs=generate_binary_build_matrix.generate_wheels_matrix(
            OperatingSystem.MACOS_ARM64
        ),
        cross_compile_arm64=False,
        macos_runner="macos-14-xlarge",
        ciflow_config=CIFlowConfig(
            labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},

.github/scripts/runner_determinator.py (vendored, 7 lines changed)
@@ -262,12 +262,7 @@ def is_exception_branch(branch: str) -> bool:
    """
    Branches that get opted out of experiments by default, until they're explicitly enabled.
    """
    return branch.split("/", maxsplit=1)[0] in {
        "main",
        "nightly",
        "release",
        "landchecks",
    }
    return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}


def load_yaml(yaml_text: str) -> Any:

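As with the earlier branch-parsing hunks, the two return statements above differ only in `maxsplit`; since only element `[0]` is used, the membership test is unchanged. A two-line check with hypothetical branch names:

```python
# The two split variants agree on the leading path component.
for branch in ("main", "release/2.5", "nightly/binaries/extra"):
    assert branch.split("/", maxsplit=1)[0] == branch.split("/")[0]
```
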
@@ -47,6 +47,9 @@ env:
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  PR_NUMBER: ${{ github.event.pull_request.number }}
  SKIP_ALL_TESTS: 0
{%- if cross_compile_arm64 %}
  CROSS_COMPILE_ARM64: 1
{% endif %}
!{{ common.concurrency(build_environment) }}

jobs:

.github/templates/upload.yml.j2 (vendored, 5 lines changed)
@@ -25,6 +25,11 @@
      DOCKER_IMAGE: !{{ config["container_image"] }}
      DOCKER_IMAGE_TAG_PREFIX: !{{ config["container_image_tag_prefix"] }}
{%- endif %}
{%- if config["package_type"] == "manywheel" %}
  {%- if config.use_split_build is defined %}
      use_split_build: !{{ config["use_split_build"] }}
  {%- endif %}
{%- endif %}
{%- if config["package_type"] == "libtorch" %}
  {%- if config["libtorch_config"] %}
      LIBTORCH_CONFIG: !{{ config["libtorch_config"] }}

.github/workflows/_binary-build-linux.yml (vendored, 10 lines changed)
@@ -26,6 +26,13 @@ on:
        default: 240
        type: number
        description: timeout for the job
      use_split_build:
        description: |
          [Experimental] Build a libtorch only wheel and build pytorch such that
          are built from the libtorch wheel.
        required: false
        type: boolean
        default: false
      ALPINE_IMAGE:
        required: false
        type: string
@@ -110,6 +117,7 @@ jobs:
      PR_NUMBER: ${{ github.event.pull_request.number }}
      PYTORCH_FINAL_PACKAGE_DIR: /artifacts
      SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
      USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
    steps:
      - name: Make the env permanent during this workflow (but not the secrets)
        shell: bash
@@ -134,6 +142,7 @@ jobs:
            echo "PR_NUMBER=${{ env.PR_NUMBER }}"
            echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
            echo "SHA1=${{ env.SHA1 }}"
            echo "USE_SPLIT_BUILD=${{ env.use_split_build }}"
          } >> "${GITHUB_ENV} }}"

      - name: List the env
@@ -252,6 +261,7 @@ jobs:
            -e PYTORCH_ROOT \
            -e SKIP_ALL_TESTS \
            -e PYTORCH_EXTRA_INSTALL_REQUIREMENTS \
            -e USE_SPLIT_BUILD \
            --tty \
            --detach \
            -v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \

.github/workflows/_binary-test-linux.yml (vendored, 9 lines changed)
@@ -64,6 +64,13 @@ on:
        required: true
        type: string
        description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, linux.arm64.2xlarge, and linux.rocm.gpu
      use_split_build:
        description: |
          [Experimental] Build a libtorch only wheel and build pytorch such that
          are built from the libtorch wheel.
        required: false
        type: boolean
        default: false
    secrets:
      github-token:
        required: true
@@ -97,6 +104,7 @@ jobs:
      PR_NUMBER: ${{ github.event.pull_request.number }}
      PYTORCH_FINAL_PACKAGE_DIR: /artifacts
      SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
      USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
    steps:
      - name: Make the env permanent during this workflow (but not the secrets)
        shell: bash
@@ -121,6 +129,7 @@ jobs:
            echo "PR_NUMBER=${{ env.PR_NUMBER }}"
            echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
            echo "SHA1=${{ env.SHA1 }}"
            echo "USE_SPLIT_BUILD=${{ env.USE_SPLIT_BUILD }}"
          } >> "${GITHUB_ENV} }}"

      - name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"

.github/workflows/_binary-upload.yml (vendored, 8 lines changed)
@@ -51,6 +51,13 @@ on:
        required: false
        type: string
        description: Desired python version
      use_split_build:
        description: |
          [Experimental] Build a libtorch only wheel and build pytorch such that
          are built from the libtorch wheel.
        required: false
        type: boolean
        default: false
    secrets:
      github-token:
        required: true
@@ -79,6 +86,7 @@ jobs:
      PR_NUMBER: ${{ github.event.pull_request.number }}
      PYTORCH_FINAL_PACKAGE_DIR: /artifacts
      SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
      USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
    steps:
      - name: Checkout PyTorch
        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main

.github/workflows/_linux-build.yml (vendored, 1 line changed)
@@ -306,6 +306,7 @@ jobs:
            -e OUR_GITHUB_JOB_ID \
            -e HUGGING_FACE_HUB_TOKEN \
            -e SCRIBE_GRAPHQL_ACCESS_TOKEN \
            -e USE_SPLIT_BUILD \
            -e BUILD_ADDITIONAL_PACKAGES \
            --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \
            --memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \

.github/workflows/docker-builds.yml (vendored, 1 line changed)
@@ -61,7 +61,6 @@ jobs:
          pytorch-linux-jammy-rocm-n-py3,
          pytorch-linux-noble-rocm-n-py3,
          pytorch-linux-noble-rocm-alpha-py3,
          pytorch-linux-jammy-rocm-n-py3-benchmarks,
          pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12,
          pytorch-linux-jammy-py3.9-gcc11,
          pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks,

								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							
							
						
						
									
										30
									
								
								.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml
									
									
									
										generated
									
									
										vendored
									
									
								
							@ -60,6 +60,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -83,6 +84,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -106,6 +108,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -126,6 +129,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -152,6 +156,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -171,6 +176,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -194,6 +200,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: manywheel-py3_10-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -217,6 +224,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: manywheel-py3_10-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -237,6 +245,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -263,6 +272,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: manywheel-py3_10-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -282,6 +292,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -305,6 +316,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      build_name: manywheel-py3_11-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -328,6 +340,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      build_name: manywheel-py3_11-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -348,6 +361,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -374,6 +388,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      build_name: manywheel-py3_11-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -393,6 +408,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -416,6 +432,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -439,6 +456,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -459,6 +477,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -485,6 +504,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -504,6 +524,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -527,6 +548,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      build_name: manywheel-py3_13-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -550,6 +572,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      build_name: manywheel-py3_13-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -570,6 +593,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -596,6 +620,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      build_name: manywheel-py3_13-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -615,6 +640,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13t"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -638,6 +664,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13t"
 | 
			
		||||
      build_name: manywheel-py3_13t-cpu-aarch64
 | 
			
		||||
      build_environment: linux-aarch64-binary-manywheel
 | 
			
		||||
@ -661,6 +688,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28_aarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13t"
 | 
			
		||||
      build_name: manywheel-py3_13t-cpu-aarch64
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -681,6 +709,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13t"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.arm64.m7g.4xlarge.ephemeral
 | 
			
		||||
@ -707,6 +736,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda-aarch64
 | 
			
		||||
      DOCKER_IMAGE: manylinuxaarch64-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13t"
 | 
			
		||||
      build_name: manywheel-py3_13t-cuda-aarch64-12_9
 | 
			
		||||
    secrets:
 | 
			
		||||
110  .github/workflows/generated-linux-binary-manywheel-main.yml  (generated, vendored)
@ -42,7 +42,54 @@ jobs:
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
  manywheel-py3_12-cuda12_8-build:
 | 
			
		||||
  manywheel-py3_9-cuda12_6-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: manywheel
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu126
 | 
			
		||||
      GPU_ARCH_VERSION: 12.6
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.6
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_6
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_9-cuda12_6-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs:
 | 
			
		||||
      - manywheel-py3_9-cuda12_6-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
    uses: ./.github/workflows/_binary-test-linux.yml
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: manywheel
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu126
 | 
			
		||||
      GPU_ARCH_VERSION: 12.6
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.6
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_6
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
  manywheel-py3_9-cuda12_8-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
@ -56,17 +103,18 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.8
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda12_8
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_8
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_12-cuda12_8-test:  # Testing
 | 
			
		||||
  manywheel-py3_9-cuda12_8-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs:
 | 
			
		||||
      - manywheel-py3_12-cuda12_8-build
 | 
			
		||||
      - manywheel-py3_9-cuda12_8-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
    uses: ./.github/workflows/_binary-test-linux.yml
 | 
			
		||||
    with:
 | 
			
		||||
@ -79,8 +127,56 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.8
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cuda12_8
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_8
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.g4dn.4xlarge.nvidia.gpu  # 12.8 and 12.9 build need sm_70+ runner
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
 | 
			
		||||
  manywheel-py3_9-cuda12_9-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-build-linux.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: manywheel
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu129
 | 
			
		||||
      GPU_ARCH_VERSION: 12.9
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  manywheel-py3_9-cuda12_9-test:  # Testing
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
 | 
			
		||||
    needs:
 | 
			
		||||
      - manywheel-py3_9-cuda12_9-build
 | 
			
		||||
      - get-label-type
 | 
			
		||||
    uses: ./.github/workflows/_binary-test-linux.yml
 | 
			
		||||
    with:
 | 
			
		||||
      PYTORCH_ROOT: /pytorch
 | 
			
		||||
      PACKAGE_TYPE: manywheel
 | 
			
		||||
      # TODO: This is a legacy variable that we eventually want to get rid of in
 | 
			
		||||
      #       favor of GPU_ARCH_VERSION
 | 
			
		||||
      DESIRED_CUDA: cu129
 | 
			
		||||
      GPU_ARCH_VERSION: 12.9
 | 
			
		||||
      GPU_ARCH_TYPE: cuda
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cuda12.9
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cuda12_9
 | 
			
		||||
      build_environment: linux-binary-manywheel
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      runs_on: linux.g4dn.4xlarge.nvidia.gpu  # 12.8 and 12.9 build need sm_70+ runner
 | 
			
		||||
171  .github/workflows/generated-linux-binary-manywheel-nightly.yml  (generated, vendored)
File diff suppressed because it is too large
2  .github/workflows/generated-linux-binary-manywheel-rocm-main.yml  (generated, vendored)
@ -58,6 +58,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: rocm
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build_name: manywheel-py3_9-rocm6_4
 | 
			
		||||
@ -82,6 +83,7 @@ jobs:
 | 
			
		||||
      SKIP_ALL_TESTS: 1
 | 
			
		||||
      DOCKER_IMAGE: manylinux2_28-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: rocm6.4
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
    steps:
 | 
			
		||||
      - name: Setup ROCm
 | 
			
		||||
15  .github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml  (generated, vendored)
@ -60,6 +60,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      runs_on: linux.s390x
 | 
			
		||||
      ALPINE_IMAGE: "docker.io/s390x/alpine"
 | 
			
		||||
@ -83,6 +84,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cpu-s390x
 | 
			
		||||
      build_environment: linux-s390x-binary-manywheel
 | 
			
		||||
@ -105,6 +107,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.9"
 | 
			
		||||
      build_name: manywheel-py3_9-cpu-s390x
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -124,6 +127,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      runs_on: linux.s390x
 | 
			
		||||
      ALPINE_IMAGE: "docker.io/s390x/alpine"
 | 
			
		||||
@ -147,6 +151,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: manywheel-py3_10-cpu-s390x
 | 
			
		||||
      build_environment: linux-s390x-binary-manywheel
 | 
			
		||||
@ -169,6 +174,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.10"
 | 
			
		||||
      build_name: manywheel-py3_10-cpu-s390x
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -188,6 +194,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      runs_on: linux.s390x
 | 
			
		||||
      ALPINE_IMAGE: "docker.io/s390x/alpine"
 | 
			
		||||
@ -211,6 +218,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      build_name: manywheel-py3_11-cpu-s390x
 | 
			
		||||
      build_environment: linux-s390x-binary-manywheel
 | 
			
		||||
@ -233,6 +241,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.11"
 | 
			
		||||
      build_name: manywheel-py3_11-cpu-s390x
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -252,6 +261,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      runs_on: linux.s390x
 | 
			
		||||
      ALPINE_IMAGE: "docker.io/s390x/alpine"
 | 
			
		||||
@ -275,6 +285,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cpu-s390x
 | 
			
		||||
      build_environment: linux-s390x-binary-manywheel
 | 
			
		||||
@ -297,6 +308,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.12"
 | 
			
		||||
      build_name: manywheel-py3_12-cpu-s390x
 | 
			
		||||
    secrets:
 | 
			
		||||
@ -316,6 +328,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      runs_on: linux.s390x
 | 
			
		||||
      ALPINE_IMAGE: "docker.io/s390x/alpine"
 | 
			
		||||
@ -339,6 +352,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      build_name: manywheel-py3_13-cpu-s390x
 | 
			
		||||
      build_environment: linux-s390x-binary-manywheel
 | 
			
		||||
@ -361,6 +375,7 @@ jobs:
 | 
			
		||||
      GPU_ARCH_TYPE: cpu-s390x
 | 
			
		||||
      DOCKER_IMAGE: pytorch/manylinuxs390x-builder
 | 
			
		||||
      DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
 | 
			
		||||
      use_split_build: False
 | 
			
		||||
      DESIRED_PYTHON: "3.13"
 | 
			
		||||
      build_name: manywheel-py3_13-cpu-s390x
 | 
			
		||||
    secrets:
 | 
			
		||||
 | 
			
		||||
@ -85,7 +85,7 @@ jobs:
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3_10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
 | 
			
		||||
2  .github/workflows/inductor-periodic.yml  (vendored)
@ -77,7 +77,7 @@ jobs:
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-rocm-py3_10
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
 | 
			
		||||
      sync-tag: rocm-build
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
111  .github/workflows/pull.yml  (vendored)
@ -254,6 +254,68 @@ jobs:
 | 
			
		||||
      timeout-minutes: 600
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-build-distributed:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 | 
			
		||||
      cuda-arch-list: '7.5'
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
          { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-test-distributed:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11-test
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-cuda12_8-py3_10-gcc11-build-distributed
 | 
			
		||||
      - target-determination
 | 
			
		||||
    with:
 | 
			
		||||
      timeout-minutes: 360
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-build:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 | 
			
		||||
      cuda-arch-list: 8.9
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-test:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-cuda12_8-py3_10-gcc11-build
 | 
			
		||||
      - target-determination
 | 
			
		||||
    with:
 | 
			
		||||
      timeout-minutes: 360
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build:
 | 
			
		||||
    name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
@ -268,6 +330,30 @@ jobs:
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-py3_9-clang9-xla-build:
 | 
			
		||||
    name: linux-jammy-py3_9-clang9-xla
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-py3.9-clang9-xla
 | 
			
		||||
      docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-py3_9-clang9-xla-test:
 | 
			
		||||
    name: linux-jammy-py3_9-clang9-xla
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs: linux-jammy-py3_9-clang9-xla-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-py3.9-clang9-xla
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cpu-py3_10-gcc11-bazel-test:
 | 
			
		||||
    name: linux-jammy-cpu-py3.10-gcc11-bazel-test
 | 
			
		||||
    uses: ./.github/workflows/_bazel-build-test.yml
 | 
			
		||||
@ -343,6 +429,31 @@ jobs:
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc9-inductor-build:
 | 
			
		||||
    name: cuda12.8-py3.10-gcc9-sm75
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
 | 
			
		||||
      cuda-arch-list: '7.5'
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc9-inductor-test:
 | 
			
		||||
    name: cuda12.8-py3.10-gcc9-sm75
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-xpu-2025_1-py3_9-build:
 | 
			
		||||
    name: linux-jammy-xpu-2025.1-py3.9
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
37  .github/workflows/trunk.yml  (vendored)
@ -63,43 +63,6 @@ jobs:
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-build:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 | 
			
		||||
      cuda-arch-list: '7.5 8.9'
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
 | 
			
		||||
          { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
          { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
          { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
 | 
			
		||||
          { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-test:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs:
 | 
			
		||||
      - linux-jammy-cuda12_8-py3_10-gcc11-build
 | 
			
		||||
      - target-determination
 | 
			
		||||
    with:
 | 
			
		||||
      timeout-minutes: 360
 | 
			
		||||
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated
 | 
			
		||||
  linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build:
 | 
			
		||||
    name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops
 | 
			
		||||
28  .github/workflows/unstable.yml  (vendored)
@ -12,9 +12,7 @@ concurrency:
 | 
			
		||||
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
 | 
			
		||||
  cancel-in-progress: true
 | 
			
		||||
 | 
			
		||||
permissions:
 | 
			
		||||
  id-token: write
 | 
			
		||||
  contents: read
 | 
			
		||||
permissions: read-all
 | 
			
		||||
 | 
			
		||||
jobs:
 | 
			
		||||
  # There must be at least one job here to satisfy GitHub action workflow syntax
 | 
			
		||||
@ -53,27 +51,3 @@ jobs:
 | 
			
		||||
      issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
 | 
			
		||||
      curr_branch: ${{ github.head_ref || github.ref_name }}
 | 
			
		||||
      curr_ref_type: ${{ github.ref_type }}
 | 
			
		||||
 | 
			
		||||
  linux-jammy-py3_9-clang9-xla-build:
 | 
			
		||||
    name: linux-jammy-py3_9-clang9-xla
 | 
			
		||||
    uses: ./.github/workflows/_linux-build.yml
 | 
			
		||||
    needs: get-label-type
 | 
			
		||||
    with:
 | 
			
		||||
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
 | 
			
		||||
      build-environment: linux-jammy-py3.9-clang9-xla
 | 
			
		||||
      docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite
 | 
			
		||||
      test-matrix: |
 | 
			
		||||
        { include: [
 | 
			
		||||
          { config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" },
 | 
			
		||||
        ]}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
 | 
			
		||||
  linux-jammy-py3_9-clang9-xla-test:
 | 
			
		||||
    name: linux-jammy-py3_9-clang9-xla
 | 
			
		||||
    uses: ./.github/workflows/_linux-test.yml
 | 
			
		||||
    needs: linux-jammy-py3_9-clang9-xla-build
 | 
			
		||||
    with:
 | 
			
		||||
      build-environment: linux-jammy-py3.9-clang9-xla
 | 
			
		||||
      docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }}
 | 
			
		||||
      test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }}
 | 
			
		||||
    secrets: inherit
 | 
			
		||||
3  .gitignore  (vendored)
@ -146,9 +146,6 @@ merge_record.json
 | 
			
		||||
torchgen/packaged/*
 | 
			
		||||
!torchgen/packaged/README.md
 | 
			
		||||
 | 
			
		||||
# This file is injected by ROCm build scripts to bootstrap in torch/__init__.py.
 | 
			
		||||
torch/_rocm_init.py
 | 
			
		||||
 | 
			
		||||
# IPython notebook checkpoints
 | 
			
		||||
.ipynb_checkpoints
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1452,6 +1452,8 @@ init_command = [
 | 
			
		||||
    'python3',
 | 
			
		||||
    'tools/linter/adapters/pip_init.py',
 | 
			
		||||
    '--dry-run={{DRYRUN}}',
 | 
			
		||||
    '--no-black-binary',
 | 
			
		||||
    'black==23.12.1',
 | 
			
		||||
    'usort==1.0.8.post1',
 | 
			
		||||
    'isort==6.0.1',
 | 
			
		||||
    'ruff==0.12.2',  # sync with RUFF
 | 
			
		||||
12  .pre-commit-config.yaml  (new file)
@ -0,0 +1,12 @@
 | 
			
		||||
repos:
 | 
			
		||||
  - repo: local
 | 
			
		||||
    hooks:
 | 
			
		||||
      - id: lintrunner
 | 
			
		||||
        name: Run Lintrunner in an isolated venv before every push. The first run may be slow...
 | 
			
		||||
        entry: python scripts/run_lintrunner.py   # wrapper below
 | 
			
		||||
        language: python                          # pre‑commit manages venv for the wrapper
 | 
			
		||||
        additional_dependencies: []               # wrapper handles lintrunner install
 | 
			
		||||
        always_run: true
 | 
			
		||||
        stages: [pre-push]                        # fire only on pre‑push
 | 
			
		||||
        pass_filenames: false                     # Lintrunner gets no per‑file args
 | 
			
		||||
        verbose: true                             # stream output as it is produced...allegedly anyways
 | 
			
		||||
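The new .pre-commit-config.yaml above registers the lintrunner wrapper as a pre-push hook. The following is not part of the diff: a minimal sketch of how such a hook would typically be activated locally, assuming the pre-commit tool is available and the wrapper script exists as referenced above.

```bash
# One-time local setup (assumed workflow, not mandated by this PR):
pip install pre-commit                      # install the pre-commit framework
pre-commit install --hook-type pre-push     # register the pre-push hooks from .pre-commit-config.yaml

# After this, `git push` triggers the lintrunner wrapper; to invoke it manually:
pre-commit run --hook-stage pre-push --all-files
```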
@ -8,7 +8,8 @@
 | 
			
		||||
  Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir'
 | 
			
		||||
- Do NOT run setup.py, you do not have a working build environment
 | 
			
		||||
- Do NOT run pre-commit, it is not setup
 | 
			
		||||
- To run lint, run 'lintrunner -a' (which will autoapply changes)
 | 
			
		||||
- To run lint, run 'lintrunner -a' (which will autoapply changes).  lintrunner
 | 
			
		||||
  ONLY accepts this flag, do not try to run on individual files.
 | 
			
		||||
- Do NOT attempt to install dependencies, you do not have Internet access
 | 
			
		||||
- When you are ready to make a PR, do exactly these steps:
 | 
			
		||||
  - git stash -u
 | 
			
		||||
 | 
			
		||||
@ -239,9 +239,7 @@ option(USE_XPU "Use XPU" ON)
 | 
			
		||||
cmake_dependent_option(
 | 
			
		||||
  BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
 | 
			
		||||
  "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
 | 
			
		||||
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF)
 | 
			
		||||
cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)
 | 
			
		||||
option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF)
 | 
			
		||||
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
 | 
			
		||||
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
 | 
			
		||||
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
 | 
			
		||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
 | 
			
		||||
@ -253,6 +251,7 @@ cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF)
 | 
			
		||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
 | 
			
		||||
option(USE_KINETO "Use Kineto profiling library" ON)
 | 
			
		||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
 | 
			
		||||
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
 | 
			
		||||
option(USE_GFLAGS "Use GFLAGS" OFF)
 | 
			
		||||
option(USE_GLOG "Use GLOG" OFF)
 | 
			
		||||
option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
 | 
			
		||||
@ -261,13 +260,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
 | 
			
		||||
option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF)
 | 
			
		||||
option(USE_NATIVE_ARCH "Use -march=native" OFF)
 | 
			
		||||
cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
 | 
			
		||||
option(USE_DISTRIBUTED "Use distributed" ON)
 | 
			
		||||
cmake_dependent_option(USE_NCCL "Use NCCL" ON
 | 
			
		||||
                       "USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
 | 
			
		||||
                       "USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
 | 
			
		||||
cmake_dependent_option(USE_XCCL "Use XCCL" ON
 | 
			
		||||
                       "USE_XPU;UNIX;NOT APPLE" OFF)
 | 
			
		||||
cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
 | 
			
		||||
cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF)
 | 
			
		||||
cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
 | 
			
		||||
cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
 | 
			
		||||
                       OFF)
 | 
			
		||||
@ -325,6 +322,7 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN})
 | 
			
		||||
cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN"
 | 
			
		||||
                       OFF)
 | 
			
		||||
option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF)
 | 
			
		||||
option(USE_DISTRIBUTED "Use distributed" ON)
 | 
			
		||||
cmake_dependent_option(
 | 
			
		||||
  USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
 | 
			
		||||
  "USE_DISTRIBUTED" OFF)
 | 
			
		||||
@ -836,11 +834,10 @@ include(ExternalProject)
 | 
			
		||||
 | 
			
		||||
# ---[ Dependencies ---[ FBGEMM doesn't work on x86 32bit and
 | 
			
		||||
# CMAKE_SYSTEM_PROCESSOR thinks its 64bit
 | 
			
		||||
if(USE_FBGEMM AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
 | 
			
		||||
  message(WARNING
 | 
			
		||||
    "x64 operating system is required for FBGEMM. "
 | 
			
		||||
    "Not compiling with FBGEMM. "
 | 
			
		||||
    "Turn this warning off by USE_FBGEMM=OFF.")
 | 
			
		||||
if(USE_FBGEMM
 | 
			
		||||
   AND((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VOID_P EQUAL
 | 
			
		||||
                                                      4)
 | 
			
		||||
        OR CMAKE_SYSTEM_PROCESSOR STREQUAL "x86"))
 | 
			
		||||
  set(USE_FBGEMM OFF)
 | 
			
		||||
endif()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -164,7 +164,6 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd
 | 
			
		||||
# torch.export
 | 
			
		||||
/torch/export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi
 | 
			
		||||
/torch/_export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi
 | 
			
		||||
/torch/_export/serde/schema.py @SherlockNoMad @zhxchen17
 | 
			
		||||
 | 
			
		||||
# Dynamic Shapes
 | 
			
		||||
/torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka
 | 
			
		||||
10  README.md
@ -1,4 +1,4 @@
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
--------------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
@ -72,7 +72,7 @@ Elaborating Further:
 | 
			
		||||
 | 
			
		||||
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
 | 
			
		||||
computation by a huge amount.
 | 
			
		||||
@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc
 | 
			
		||||
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
 | 
			
		||||
You get the best of speed and flexibility for your crazy research.
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### Python First
 | 
			
		||||
 | 
			
		||||
@ -243,7 +243,7 @@ git submodule update --init --recursive
 | 
			
		||||
 | 
			
		||||
```bash
 | 
			
		||||
conda install cmake ninja
 | 
			
		||||
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section above
 | 
			
		||||
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below
 | 
			
		||||
pip install -r requirements.txt
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
@ -560,7 +560,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi
 | 
			
		||||
 | 
			
		||||
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
 | 
			
		||||
 | 
			
		||||
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
 | 
			
		||||
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
 | 
			
		||||
A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). <!-- codespell:ignore -->
 | 
			
		||||
 | 
			
		||||
Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.
 | 
			
		||||
 | 
			
		||||
@ -119,8 +119,6 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp")
 | 
			
		||||
file(GLOB_RECURSE native_mps_mm "native/mps/*.mm")
 | 
			
		||||
file(GLOB_RECURSE native_mps_metal "native/mps/*.metal")
 | 
			
		||||
file(GLOB_RECURSE native_mps_h "native/mps/*.h")
 | 
			
		||||
file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm")
 | 
			
		||||
file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal")
 | 
			
		||||
 | 
			
		||||
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
 | 
			
		||||
file(GLOB native_quantized_cpp
 | 
			
		||||
@ -180,27 +178,26 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
 | 
			
		||||
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
 | 
			
		||||
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
 | 
			
		||||
if(USE_FLASH_ATTENTION)
 | 
			
		||||
  if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
 | 
			
		||||
    message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
 | 
			
		||||
    caffe2_update_option(USE_ROCM_CK_SDPA ON)
 | 
			
		||||
  endif()
 | 
			
		||||
  if(USE_ROCM_CK_SDPA)
 | 
			
		||||
    if(DEFINED ENV{PYTORCH_ROCM_ARCH})
 | 
			
		||||
      list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
 | 
			
		||||
      if(NUM_ARCHS GREATER 1)
 | 
			
		||||
        message(WARNING "Building CK for multiple archs can increase build time considerably!
 | 
			
		||||
        Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
 | 
			
		||||
  if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
 | 
			
		||||
    set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
 | 
			
		||||
      if(USE_CK_FLASH_ATTENTION STREQUAL "1")
 | 
			
		||||
        if(DEFINED ENV{PYTORCH_ROCM_ARCH})
 | 
			
		||||
          list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
 | 
			
		||||
          if(NUM_ARCHS GREATER 1)
 | 
			
		||||
            message(WARNING "Building CK for multiple archs can increase build time considerably!
 | 
			
		||||
            Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
 | 
			
		||||
          endif()
 | 
			
		||||
        endif()
 | 
			
		||||
        message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
 | 
			
		||||
        message(STATUS "Generating CK kernel instances...")
 | 
			
		||||
        add_subdirectory(native/transformers/hip/flash_attn/ck)
 | 
			
		||||
        file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
 | 
			
		||||
        list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
 | 
			
		||||
        # FAv3 Generation
 | 
			
		||||
        add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
 | 
			
		||||
        file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
 | 
			
		||||
        list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
 | 
			
		||||
      endif()
 | 
			
		||||
    endif()
 | 
			
		||||
    message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
 | 
			
		||||
    message(STATUS "Generating CK kernel instances...")
 | 
			
		||||
    add_subdirectory(native/transformers/hip/flash_attn/ck)
 | 
			
		||||
    file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
 | 
			
		||||
    list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
 | 
			
		||||
    # FAv3 Generation
 | 
			
		||||
    add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
 | 
			
		||||
    file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
 | 
			
		||||
    list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
 | 
			
		||||
  endif()
 | 
			
		||||
  file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
 | 
			
		||||
  file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
 | 
			
		||||
@@ -419,42 +416,40 @@ if(USE_CUDA)
endif()

if(USE_ROCM)
  if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
    # NOTE: The PyTorch build does not actually add_subdirectory
    # third_party/composable_kernel or use it as a CMake library. What is used
    # is header only, so this should be ok, except that the CMake build generates
    # a ck/config.h. We just do that part here. Without this, the ck.h from the
    # ROCM SDK may get accidentally used instead.
    function(_pytorch_rocm_generate_ck_conf)
      set(CK_ENABLE_INT8 "ON")
      set(CK_ENABLE_FP16 "ON")
      set(CK_ENABLE_FP32 "ON")
      set(CK_ENABLE_FP64 "ON")
      set(CK_ENABLE_BF16 "ON")
      set(CK_ENABLE_FP8 "ON")
      set(CK_ENABLE_BF8 "ON")
      set(CK_USE_XDL "ON")
      set(CK_USE_WMMA "ON")
      configure_file(
        "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
        "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
        )
    endfunction()
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
    _pytorch_rocm_generate_ck_conf()
  endif()
  # NOTE: The PyTorch build does not actually add_subdirectory
  # third_party/composable_kernel or use it as a CMake library. What is used
  # is header only, so this should be ok, except that the CMake build generates
  # a ck/config.h. We just do that part here. Without this, the ck.h from the
  # ROCM SDK may get accidentally used instead.
  function(_pytorch_rocm_generate_ck_conf)
    set(CK_ENABLE_INT8 "ON")
    set(CK_ENABLE_FP16 "ON")
    set(CK_ENABLE_FP32 "ON")
    set(CK_ENABLE_FP64 "ON")
    set(CK_ENABLE_BF16 "ON")
    set(CK_ENABLE_FP8 "ON")
    set(CK_ENABLE_BF8 "ON")
    set(CK_USE_XDL "ON")
    set(CK_USE_WMMA "ON")
    configure_file(
      "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
      "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
      )
  endfunction()
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
  _pytorch_rocm_generate_ck_conf()

  # Next two lines are needed because TunableOp uses third-party/fmt
  list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
  list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
  if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
    list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
  endif()
if(USE_FLASH_ATTENTION)
  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
endif()
  list(APPEND ATen_HIP_SRCS
    ${ATen_HIP_SRCS}
    ${hip_hip}
@@ -464,17 +459,12 @@ if(USE_ROCM)
    ${native_quantized_hip_hip}
    ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
  )
  if(NOT USE_ROCM_CK_GEMM)
  if(WIN32) # Windows doesn't support Composable Kernels
    file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
    file(GLOB native_hip_ck "native/hip/ck*.hip")
    exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
      ${native_hip_bgemm} ${native_hip_ck})
  endif()
  if(WIN32) # Windows doesn't support Composable Kernels and Triton
    exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
      ${native_transformers_hip_hip} ${native_transformers_hip_cpp})
  endif()

  # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
  list(APPEND all_hip_cpp
    ${native_nested_hip_cpp}
@@ -709,10 +699,10 @@ endif()
if(USE_MPS)
    include(../../../cmake/Metal.cmake)

    set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm})
    set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h})

    if(CAN_COMPILE_METAL)
        foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
        foreach(SHADER ${native_mps_metal})
            cmake_path(GET SHADER STEM TGT_STEM)
            string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
            list(APPEND AIR_BASIC ${TGT_BASIC})
@@ -727,7 +717,7 @@ if(USE_MPS)
        add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
    else()
        file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
        foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
        foreach(SHADER ${native_mps_metal})
            cmake_path(GET SHADER STEM TGT_STEM)
            string(CONCAT SHADER_HDR_NAME  "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h")
            metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME})

@@ -480,9 +480,6 @@ at::BlasBackend Context::blasPreferredBackend() {
  // call site for blasPreferredBackend(), we set it to an actual value.
  if (blas_preferred_backend == at::BlasBackend::Default) {
    blas_preferred_backend = at::BlasBackend::Cublas;
    // This logic sits in the getter because it needs to validate
    // values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT
    // which initialize the backend without calling the setter
#ifdef USE_ROCM
    // AMD Instinct targets prefer hipblaslt
    static const bool hipblaslt_preferred = []() {
@@ -512,10 +509,6 @@ at::BlasBackend Context::blasPreferredBackend() {
  // hipblaslt support for all archs is not as complete as hipblas
  if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
    static const bool hipblaslt_unsupported = []() {
      if(!hasCuBLASLt())
      {
          return true;
      }
      static const std::vector<std::string> archs = {
          "gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
@@ -541,24 +534,6 @@ at::BlasBackend Context::blasPreferredBackend() {
  return blas_preferred_backend;
}

bool Context::ckSupported() {
#ifdef USE_ROCM
  static const std::vector<std::string> supported_archs = {
    "gfx90a", "gfx942", "gfx950"
  };
  for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
    if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
      TORCH_WARN_ONCE(
        "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
      return false;
    }
  }
  return true;
#else
  return false;
#endif
}

void Context::setBlasPreferredBackend(at::BlasBackend b) {
#ifdef _MSC_VER
  TORCH_WARN_ONCE(
@@ -568,14 +543,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
  TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
      "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
#ifdef USE_ROCM
  static const bool ckSupportedFlag = ckSupported();
  static const bool hasCKGEMMFlag = hasCKGEMM();
  TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag),
      "Cannot set preferred blas backend to CK since following conditions are not true: ",
      "architecture supported for CK: ", ckSupportedFlag,
      ", PyTorch built with CK GEMM support: ", hasCKGEMMFlag);
#endif
  TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
      "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
  if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
    TORCH_WARN_ONCE(
      "torch.backends.cuda.preferred_blas_library is an experimental feature. "
@@ -587,40 +556,35 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#endif
}

at::ROCmFABackend Context::getROCmFAPreferredBackend() {
#ifdef USE_ROCM
  // Set potential "Default" value so we don't have to interpret at call sites.
  // We use aotriton backend as the default, for now.
  if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) {
    rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
  } else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) {
    // This logic sits in the getter because it needs to validate
    // values set via env vars such as TORCH_ROCM_FA_PREFER_CK
    // which initialize the backend without calling the setter
    // Perform validity checking
    static const bool hasCKSDPAFlag = hasCKSDPA();
    static const bool ckSupportedFlag = ckSupported();
    if(!(hasCKSDPAFlag && ckSupportedFlag)){
      TORCH_WARN_ONCE(
        "Cannot set preferred SDPA backend to CK since following conditions are not true: ",
        "architecture supported for CK: ", ckSupportedFlag,
        ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
      rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
    }
  }
#endif

at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
  return rocm_fa_preferred_backend;
}

void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {

  // TODO: add plumbing for hasCK for validity checking
  TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
      "Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
#ifdef USE_ROCM
  static const bool hasCKSDPAFlag = hasCKSDPA();
  static const bool ckSupportedFlag = ckSupported();
  TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag),
      "Cannot set preferred SDPA backend to CK since following conditions are not true: ",
      "architecture supported for CK: ", ckSupportedFlag,
      ", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
  if(b == at::ROCmFABackend::Ck) {
    static const bool ck_unsupported = []() {
      static const std::vector<std::string> archs = {
          "gfx90a",  "gfx942"
      };
      for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
        if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
          TORCH_WARN_ONCE(
            "Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
          return true;
        }
      }
      return false;
    }();
    if(!ck_unsupported) rocm_fa_preferred_backend = b;
  }
  else {
     rocm_fa_preferred_backend = b;
  }
#endif
  rocm_fa_preferred_backend = b;
}

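Aside (not part of the diff above): a minimal sketch of how the validated BLAS-backend setter in these hunks is typically driven from C++. It only relies on at::globalContext(), at::BlasBackend, and the member functions visible in the surrounding lines; the wrapper function itself is illustrative.

#include <ATen/Context.h>

// Illustrative only: prefer the CK GEMM backend when the checks in
// setBlasPreferredBackend() can pass. On a build without CK GEMM support or on
// an unsupported GPU architecture the TORCH_CHECK in the setter throws instead.
void prefer_ck_gemm_if_possible() {
  auto& ctx = at::globalContext();
  if (ctx.hasROCM()) {
    ctx.setBlasPreferredBackend(at::BlasBackend::Ck);
  }
  // Reading the preference back goes through the getter, which also
  // normalizes a still-"Default" value before returning it.
  at::BlasBackend backend = ctx.blasPreferredBackend();
  (void)backend;
}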
@@ -132,7 +132,6 @@ class TORCH_API Context {
  static bool hasKleidiAI();
  static bool hasLAPACK();
  static bool hasMKLDNN();
  static bool ckSupported();
  static bool hasMAGMA() {
    return detail::getCUDAHooks().hasMAGMA();
  }
@@ -163,12 +162,6 @@ class TORCH_API Context {
  static bool hasROCM() {
    return detail::getCUDAHooks().hasROCM();
  }
  static bool hasCKSDPA() {
    return detail::getCUDAHooks().hasCKSDPA();
  }
  static bool hasCKGEMM() {
    return detail::getCUDAHooks().hasCKGEMM();
  }
  static bool hasHIP() {
    return detail::getHIPHooks().hasHIP();
  }
@@ -259,7 +252,7 @@ class TORCH_API Context {
  at::BlasBackend blasPreferredBackend();
  void setBlasPreferredBackend(at::BlasBackend);

  at::ROCmFABackend getROCmFAPreferredBackend();
  at::ROCmFABackend getROCmFAPreferredBackend() const;
  void setROCmFAPreferredBackend(at::ROCmFABackend);

  // Note [Enabling Deterministic Operations]

@@ -31,9 +31,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
      return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
    } else {
      TORCH_CHECK(
          false,
          "pin_memory=True requires a CUDA or other accelerator backend; "
          "no pinned memory allocator is available on this system.")
          false, "Need to provide pin_memory allocator to use pin memory.")
    }
  }


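Aside (not part of the diff): the TORCH_CHECK above fires when pinned memory is requested on a build with no accelerator backend. A minimal, hedged sketch of the caller-side path that reaches GetCPUAllocatorMaybePinned():

#include <ATen/ATen.h>

// Illustrative only: allocating a pinned (page-locked) CPU staging buffer.
// With a CUDA/accelerator build this returns pinned host memory; otherwise it
// is expected to fail with the error message from the hunk above.
at::Tensor make_pinned_staging_buffer(int64_t n) {
  return at::empty(
      {n},
      at::TensorOptions().dtype(at::kFloat).device(at::kCPU).pinned_memory(true));
}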
@@ -239,7 +239,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
  KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp)

  // fp32
  KERNEL_MPS(conv_transpose3d, input, fp32)
  KERNEL_MPS(acos, fp32)
  KERNEL_MPS(asin, fp32)
  KERNEL_MPS(cosh, fp32)

@@ -97,8 +97,6 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
        return ComplexType::get();
      case Tag::Int:
        return IntType::get();
      case Tag::UInt:
        return IntType::get();
      case Tag::SymInt:
        return c10::SymIntType::get();
      case Tag::SymFloat:
@@ -322,8 +320,6 @@ IValue IValue::equals(const IValue& rhs) const {
      return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
    case Tag::Int:
      return rhs.isInt() && lhs.toInt() == rhs.toInt();
    case Tag::UInt:
      return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt();
    case Tag::SymInt:
      return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
    case Tag::SymFloat:
@@ -383,8 +379,6 @@ size_t IValue::hash(const IValue& v) {
    case Tag::Int:
      return c10::get_hash(v.payload.u.as_int);
    // NB: these are technically strict aliasing violations
    case Tag::UInt:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymInt:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymFloat:
@@ -812,8 +806,6 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
      return printComplex(out, v);
    } case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::UInt:
      return out << v.toUInt();
    case IValue::Tag::SymInt:
      return out << v.toSymInt();
    case IValue::Tag::SymFloat:

@@ -12,7 +12,6 @@
#include <c10/macros/Export.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h>
#include <limits>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@@ -161,7 +160,6 @@ struct Capsule {
  _(Double)                  \
  _(ComplexDouble)           \
  _(Int)                     \
  _(UInt)                    \
  _(SymInt)                  \
  _(SymFloat)                \
  _(SymBool)                 \
@@ -655,29 +653,6 @@ struct TORCH_API IValue final {
    }
  }

  // Unsigned
  IValue(uint64_t u) : tag( u <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt) {
    payload.u.as_uint = u;
  }


  // See Note [Meaning of HAS_u]
  // IValue type model closely follows that of c10::Scalar
  // Where all integers are upcast to 64-bit representation, and `as_int` is used as default
  // representation unless value could not be represented as signed int
  bool isUnsigned() const {
    return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0);
  }

  uint64_t toUInt() const {
    if (isUnsigned()) {
      return payload.u.as_uint;
    } else {
      TORCH_INTERNAL_ASSERT(0, "expected unsigned int");
    }
  }


  // Bool
  IValue(bool b) : tag(Tag::Bool) {
#if defined(__clang__) && defined(__x86_64__)
@@ -918,14 +893,8 @@ struct TORCH_API IValue final {
    } else {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          s.isIntegral(false), "Unknown type in Scalar");
      if (s.isUnsigned()) {
        const auto val = s.toUInt64();
        payload.u.as_uint = val;
        tag = val <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt;
      } else {
        payload.u.as_int = s.toLong();
        tag = Tag::Int;
      }
      tag = Tag::Int;
      payload.u.as_int = s.toLong();
    }
  }

@@ -949,8 +918,6 @@ struct TORCH_API IValue final {
      return toSymFloat();
    else if (isSymBool())
      return toSymBool();
    else if (isUnsigned())
      return toUInt();
    TORCH_CHECK(false, "IValue is not a Scalar");
  }

@@ -1280,8 +1247,6 @@ struct TORCH_API IValue final {
        return true;
      case Tag::Int:
        return false;
      case Tag::UInt:
        return false;
      case Tag::SymInt:
        return true;
      case Tag::SymFloat:
@@ -1378,8 +1343,6 @@ struct TORCH_API IValue final {
    union TriviallyCopyablePayload {
      TriviallyCopyablePayload() : as_int(0) {}
      int64_t as_int;
      // See Note [Meaning of HAS_u]
      uint64_t as_uint;
      double as_double;
      bool as_bool;
      // Invariant: never nullptr; null state is represented as

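Aside (not part of the diff): the IValue(uint64_t) constructor touched above keeps the signed Int tag whenever the value fits in int64_t and only falls back to the separate unsigned tag for values above INT64_MAX. A small, self-contained restatement of that boundary rule in plain C++ (no ATen types), for illustration:

// Sketch of the tagging rule in the IValue(uint64_t) constructor shown above:
// values representable as int64_t keep the signed Int tag; only values above
// INT64_MAX need the unsigned tag.
#include <cassert>
#include <cstdint>
#include <limits>

enum class Tag { Int, UInt };

inline Tag tag_for_uint64(uint64_t u) {
  return u <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())
      ? Tag::Int
      : Tag::UInt;
}

int main() {
  assert(tag_for_uint64(42) == Tag::Int);                 // fits in int64_t
  assert(tag_for_uint64(uint64_t{1} << 63) == Tag::UInt); // 2^63 > INT64_MAX
}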
@ -832,7 +832,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
 | 
			
		||||
      bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
 | 
			
		||||
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 | 
			
		||||
    at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
 | 
			
		||||
  }
 | 
			
		||||
@ -1273,7 +1273,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
 | 
			
		||||
    gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
 | 
			
		||||
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 | 
			
		||||
    at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
 | 
			
		||||
  }
 | 
			
		||||
@ -1289,7 +1289,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
 | 
			
		||||
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 | 
			
		||||
    gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
 | 
			
		||||
  }
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
 | 
			
		||||
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 | 
			
		||||
    if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
 | 
			
		||||
      gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
 | 
			
		||||
@ -1341,7 +1341,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
 | 
			
		||||
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 | 
			
		||||
    gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
 | 
			
		||||
  }
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
 | 
			
		||||
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 | 
			
		||||
    at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
 | 
			
		||||
  }
 | 
			
		||||
@ -1357,7 +1357,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
 | 
			
		||||
  if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
 | 
			
		||||
    gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
 | 
			
		||||
  }
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
 | 
			
		||||
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
 | 
			
		||||
    at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -207,27 +207,6 @@ bool CUDAHooks::hasCuBLASLt() const {
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
bool CUDAHooks::hasCKSDPA() const {
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
    return false;
 | 
			
		||||
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA)
 | 
			
		||||
    return true;
 | 
			
		||||
#else
 | 
			
		||||
    return false;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool CUDAHooks::hasCKGEMM() const {
 | 
			
		||||
#if !defined(USE_ROCM)
 | 
			
		||||
    return false;
 | 
			
		||||
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
    return true;
 | 
			
		||||
#else
 | 
			
		||||
    return false;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool CUDAHooks::hasROCM() const {
 | 
			
		||||
  // Currently, this is same as `compiledWithMIOpen`.
 | 
			
		||||
  // But in future if there are ROCm builds without MIOpen,
 | 
			
		||||
 | 
			
		||||
@ -31,8 +31,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
 | 
			
		||||
  bool hasCuSOLVER() const override;
 | 
			
		||||
  bool hasCuBLASLt() const override;
 | 
			
		||||
  bool hasROCM() const override;
 | 
			
		||||
  bool hasCKSDPA() const override;
 | 
			
		||||
  bool hasCKGEMM() const override;
 | 
			
		||||
  const at::cuda::NVRTC& nvrtc() const override;
 | 
			
		||||
  DeviceIndex current_device() const override;
 | 
			
		||||
  bool isBuilt() const override {return true;}
 | 
			
		||||
 | 
			
		||||
@ -118,14 +118,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool hasCKSDPA() const {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool hasCKGEMM() const {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual const at::cuda::NVRTC& nvrtc() const {
 | 
			
		||||
    TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -21,10 +21,6 @@ bool isMTIAHooksBuilt() {
 | 
			
		||||
 | 
			
		||||
} // namespace detail
 | 
			
		||||
 | 
			
		||||
bool MTIAHooksInterface::isAvailable() const {
 | 
			
		||||
  return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)
 | 
			
		||||
 | 
			
		||||
} // namespace at
 | 
			
		||||
 | 
			
		||||
@ -149,8 +149,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
 | 
			
		||||
    FAIL_MTIAHOOKS_FUNC(__func__);
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  virtual bool isAvailable() const override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct TORCH_API MTIAHooksArgs {};
 | 
			
		||||
 | 
			
		||||
@ -43,6 +43,7 @@ TensorBase empty_mps(
 | 
			
		||||
    int64_t nelements = c10::multiply_integers(size);
 | 
			
		||||
    auto dtype = dtype_or_default(dtype_opt);
 | 
			
		||||
    TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
 | 
			
		||||
    TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    auto dtype_meta = scalarTypeToTypeMeta(dtype);
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,11 @@ namespace at::mps {
 | 
			
		||||
 | 
			
		||||
// Helper enum to check if a MPSGraph op is supported in a given macOS version
 | 
			
		||||
enum class MacOSVersion : uint32_t {
 | 
			
		||||
  MACOS_VER_14_4_PLUS = 0,
 | 
			
		||||
  MACOS_VER_13_1_PLUS = 0,
 | 
			
		||||
  MACOS_VER_13_2_PLUS,
 | 
			
		||||
  MACOS_VER_13_3_PLUS,
 | 
			
		||||
  MACOS_VER_14_0_PLUS,
 | 
			
		||||
  MACOS_VER_14_4_PLUS,
 | 
			
		||||
  MACOS_VER_15_0_PLUS,
 | 
			
		||||
  MACOS_VER_15_1_PLUS,
 | 
			
		||||
  MACOS_VER_15_2_PLUS,
 | 
			
		||||
 | 
			
		||||
@ -32,11 +32,11 @@ MPSDevice::~MPSDevice() {
 | 
			
		||||
 | 
			
		||||
MPSDevice::MPSDevice() : _mtl_device(nil) {
 | 
			
		||||
  // Check that MacOS 13.0+ version of MPS framework is available
 | 
			
		||||
  // Create the MPSGraph and check method introduced in 14.0
 | 
			
		||||
  // Create the MPSGraph and check method introduced in 13.0
 | 
			
		||||
  // which is used by MPS backend.
 | 
			
		||||
  id mpsCD = NSClassFromString(@"MPSGraph");
 | 
			
		||||
 | 
			
		||||
  if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) {
 | 
			
		||||
  if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
 | 
			
		||||
    return;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
@ -66,12 +66,24 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
 | 
			
		||||
          isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
  static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
 | 
			
		||||
  static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
 | 
			
		||||
  static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
 | 
			
		||||
  static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
 | 
			
		||||
  static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
 | 
			
		||||
  static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
 | 
			
		||||
  static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
 | 
			
		||||
  static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
 | 
			
		||||
 | 
			
		||||
  switch (version) {
 | 
			
		||||
    case MacOSVersion::MACOS_VER_13_1_PLUS:
 | 
			
		||||
      return _macos_13_1_plus;
 | 
			
		||||
    case MacOSVersion::MACOS_VER_13_2_PLUS:
 | 
			
		||||
      return _macos_13_2_plus;
 | 
			
		||||
    case MacOSVersion::MACOS_VER_13_3_PLUS:
 | 
			
		||||
      return _macos_13_3_plus;
 | 
			
		||||
    case MacOSVersion::MACOS_VER_14_0_PLUS:
 | 
			
		||||
      return _macos_14_0_plus;
 | 
			
		||||
    case MacOSVersion::MACOS_VER_14_4_PLUS:
 | 
			
		||||
      return _macos_14_4_plus;
 | 
			
		||||
    case MacOSVersion::MACOS_VER_15_0_PLUS:
 | 
			
		||||
 | 
			
		||||
@ -34,7 +34,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
 | 
			
		||||
    case 14:
 | 
			
		||||
      switch (minor) {
 | 
			
		||||
        case 0:
 | 
			
		||||
          return true;
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
 | 
			
		||||
        case 4:
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
 | 
			
		||||
        default:
 | 
			
		||||
@ -42,7 +42,19 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
 | 
			
		||||
      }
 | 
			
		||||
    case 13:
 | 
			
		||||
      return true;
 | 
			
		||||
      switch (minor) {
 | 
			
		||||
        case 0:
 | 
			
		||||
          return true;
 | 
			
		||||
        case 1:
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
 | 
			
		||||
        case 2:
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
 | 
			
		||||
        case 3:
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
 | 
			
		||||
        default:
 | 
			
		||||
          TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
 | 
			
		||||
          return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
 | 
			
		||||
      }
 | 
			
		||||
    default:
 | 
			
		||||
      TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false");
 | 
			
		||||
      return false;
 | 
			
		||||
 | 
			
		||||
@@ -51,7 +51,7 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
// brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C
#if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5)
#define ONEDNN_UKERNEL_1
#elif ((IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR >= 6) || (IDEEP_VERSION_MAJOR > 3))
#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6)
#define ONEDNN_UKERNEL_2
#endif
#if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))))

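Aside (not part of the diff): the two #elif predicates in this hunk disagree only for ideep releases whose major version is above 3 but whose minor version is below 6 (for example a hypothetical 4.0). A tiny, illustrative restatement of both forms:

// Illustrative only: both version predicates from the hunk above, written as
// constexpr functions so the difference can be checked at compile time.
constexpr bool or_form(int major, int minor) {
  return (major == 3 && minor >= 6) || (major > 3);
}

constexpr bool combined_form(int major, int minor) {
  return major >= 3 && minor >= 6;
}

static_assert(or_form(3, 6) && combined_form(3, 6), "both accept 3.6");
static_assert(or_form(4, 0) && !combined_form(4, 0),
              "only the OR form accepts a major version above 3 with minor < 6");

int main() {}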
@@ -206,16 +206,6 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
// B Base pointer to a tensor B.
// C Pointer to a tensor C (accumulation buffer).
// Note only batch size 1 is used currently

// Define macros for available brgemm APIs
// so that callers can determine which APIs are available
#define CPUBLAS_BRGEMM_F16F16F32 // half * half -> float
#define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float
#define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float
#define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32
#define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32
#define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32

TORCH_API void brgemm(
    int64_t M,
    int64_t N,

@ -3,7 +3,6 @@
 | 
			
		||||
#include <ATen/Config.h>
 | 
			
		||||
#include <ATen/Parallel.h>
 | 
			
		||||
#include <ATen/TensorOperators.h>
 | 
			
		||||
#include <ATen/native/CanUse32BitIndexMath.h>
 | 
			
		||||
#include <ATen/native/ConvolutionMM3d.h>
 | 
			
		||||
#include <ATen/native/ConvUtils.h>
 | 
			
		||||
#include <ATen/native/Pool.h>
 | 
			
		||||
@ -464,7 +463,7 @@ struct ConvParams {
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
    // native kernel doesn't support 64-bit non-splittable case
 | 
			
		||||
    if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
 | 
			
		||||
    if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
 | 
			
		||||
      static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
 | 
			
		||||
      if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
 | 
			
		||||
        TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"
 | 
			
		||||
 | 
			
		||||
@ -282,14 +282,6 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // not coalsced, so now let try to capture lane-matches...
 | 
			
		||||
 | 
			
		||||
    if (numel > 16 /*<-hueristic threshold*/ * 64 ) {
 | 
			
		||||
      // well shucks, unlikely to capture same-dest atomics in a wave.
 | 
			
		||||
      // fall back to direct fastAtomic...
 | 
			
		||||
      fastAtomicAdd(self_ptr, index, numel, value, true);
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd
 | 
			
		||||
    // __match_any_sync() -- returns bit mask of the threads that have same dest addr
 | 
			
		||||
    auto mask = __match_any_sync(__activemask(), (int64_t)dst);
 | 
			
		||||
 | 
			
		||||
										
											
File diff suppressed because it is too large.
							@ -70,31 +70,4 @@ void run_cudnn_SDP_bprop(
 | 
			
		||||
    const Tensor& dropoutseed,
 | 
			
		||||
    const Tensor& dropoutoffset);
 | 
			
		||||
 | 
			
		||||
void run_cudnn_SDP_bprop_nestedtensor(
 | 
			
		||||
    int64_t b,
 | 
			
		||||
    int64_t h_q,
 | 
			
		||||
    int64_t h_k,
 | 
			
		||||
    int64_t h_v,
 | 
			
		||||
    int64_t s_q,
 | 
			
		||||
    int64_t s_kv,
 | 
			
		||||
    int64_t d_qk,
 | 
			
		||||
    int64_t d_v,
 | 
			
		||||
    float scaling_factor,
 | 
			
		||||
    bool is_causal,
 | 
			
		||||
    float dropout_probability,
 | 
			
		||||
    const Tensor& cum_seqlen_q,
 | 
			
		||||
    const Tensor& cum_seqlen_kv,
 | 
			
		||||
    const Tensor& q,
 | 
			
		||||
    const Tensor& k,
 | 
			
		||||
    const Tensor& v,
 | 
			
		||||
    const std::optional<Tensor>& attn_bias,
 | 
			
		||||
    const Tensor& o,
 | 
			
		||||
    const Tensor& dO,
 | 
			
		||||
    const Tensor& softmaxstats,
 | 
			
		||||
    Tensor& dQ,
 | 
			
		||||
    Tensor& dK,
 | 
			
		||||
    Tensor& dV,
 | 
			
		||||
    const Tensor& dropoutseed,
 | 
			
		||||
    const Tensor& dropoutoffset);
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,6 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
 | 
			
		||||
  static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
template <>
 | 
			
		||||
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
 | 
			
		||||
template <>
 | 
			
		||||
@ -19,7 +18,7 @@ template <>
 | 
			
		||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 | 
			
		||||
template <>
 | 
			
		||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
#undef __HIP_NO_HALF_CONVERSIONS__
 | 
			
		||||
#include <ATen/native/hip/ck_gemm.h>
 | 
			
		||||
 | 
			
		||||
#if defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#include <ATen/native/hip/ck_gemm.h>
 | 
			
		||||
#include <ATen/native/hip/ck_gemm_template.h>
 | 
			
		||||
#include <ck/utility/sequence.hpp>
 | 
			
		||||
 | 
			
		||||
@ -782,4 +781,3 @@ void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
#endif // USE_ROCM_CK_GEMM
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
#undef __HIP_NO_HALF_CONVERSIONS__
 | 
			
		||||
 | 
			
		||||
#include <ATen/native/hip/ck_gemm.h>
 | 
			
		||||
#if defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#include <ATen/native/hip/ck_gemm_template.h>
 | 
			
		||||
#include <ck/utility/sequence.hpp>
 | 
			
		||||
 | 
			
		||||
@ -485,4 +484,3 @@ void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
#endif // USE_ROCM_CK_GEMM
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,6 @@
 | 
			
		||||
#undef __HIP_NO_HALF_CONVERSIONS__
 | 
			
		||||
 | 
			
		||||
#include <ATen/native/hip/ck_gemm.h>
 | 
			
		||||
#if defined(USE_ROCM_CK_GEMM)
 | 
			
		||||
#include <ATen/native/hip/ck_gemm_template.h>
 | 
			
		||||
 | 
			
		||||
#include <ck/utility/sequence.hpp>
 | 
			
		||||
@ -607,4 +606,3 @@ void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace at::native
 | 
			
		||||
#endif // USE_ROCM_CK_GEMM
 | 
			
		||||
 | 
			
		||||
@ -88,8 +88,14 @@ std::string getArrayRefString(const IntArrayRef s);
 | 
			
		||||
// use has_storage() on the returned tensor to determine if src actually is a view
 | 
			
		||||
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
 | 
			
		||||
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
 | 
			
		||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
 | 
			
		||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
 | 
			
		||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
 | 
			
		||||
                               MPSGraphTensor* inputTensor,
 | 
			
		||||
                               const TensorBase& input,
 | 
			
		||||
                               bool includesInt64 = false);
 | 
			
		||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
 | 
			
		||||
                                 MPSGraphTensor* inputTensor,
 | 
			
		||||
                                 const TensorBase& input,
 | 
			
		||||
                                 bool includesInt64 = false);
 | 
			
		||||
 | 
			
		||||
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
 | 
			
		||||
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
 | 
			
		||||
@ -429,6 +435,14 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
 | 
			
		||||
// Common math operations
 | 
			
		||||
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
 | 
			
		||||
 | 
			
		||||
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                            \
 | 
			
		||||
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                        \
 | 
			
		||||
    TORCH_WARN_ONCE(                                                                                                     \
 | 
			
		||||
        "MPS: no support for int64 for ",                                                                                \
 | 
			
		||||
        op_name,                                                                                                         \
 | 
			
		||||
        ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Returns distance from lowest to highest element offset in given tensor.
 | 
			
		||||
 */
 | 
			
		||||
@ -604,6 +618,10 @@ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds,
 | 
			
		||||
  runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool supportsComplex() {
 | 
			
		||||
  return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// MPS yet to support double types, but starting from MacOS 14, supports bfloat16
 | 
			
		||||
inline bool supportedFloatingType(ScalarType dtype) {
 | 
			
		||||
  return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
 | 
			
		||||
@ -615,7 +633,7 @@ inline bool supportedFloatingType(const TensorBase& t) {
 | 
			
		||||
 | 
			
		||||
inline bool supportedFloatingOrComplexType(ScalarType dtype) {
 | 
			
		||||
  if (dtype == kComplexFloat || dtype == kComplexHalf) {
 | 
			
		||||
    return true;
 | 
			
		||||
    return supportsComplex();
 | 
			
		||||
  }
 | 
			
		||||
  return supportedFloatingType(dtype);
 | 
			
		||||
}
 | 
			
		||||
@ -623,6 +641,11 @@ inline bool supportedFloatingOrComplexType(const TensorBase& t) {
 | 
			
		||||
  return supportedFloatingOrComplexType(t.scalar_type());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline void checkSupportsBFloat16() {
 | 
			
		||||
  TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
 | 
			
		||||
                   "MPS bfloat16 type is supported on MacOS 14.0 or newer.");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline bool needsGather(const TensorBase& t) {
 | 
			
		||||
  static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
 | 
			
		||||
  return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());
 | 
			
		||||
 | 
			
		||||
@ -89,6 +89,10 @@ void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds,
 | 
			
		||||
  mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static inline void checkSupportsComplex() {
 | 
			
		||||
  TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
MPSDataType getMPSDataType(ScalarType scalar_type) {
 | 
			
		||||
  switch (scalar_type) {
 | 
			
		||||
    case ScalarType::Float:
 | 
			
		||||
@ -96,6 +100,7 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
 | 
			
		||||
    case ScalarType::Half:
 | 
			
		||||
      return MPSDataTypeFloat16;
 | 
			
		||||
    case ScalarType::BFloat16:
 | 
			
		||||
      checkSupportsBFloat16();
 | 
			
		||||
      return MPSDataTypeBFloat16;
 | 
			
		||||
    case ScalarType::Int:
 | 
			
		||||
      return MPSDataTypeInt32;
 | 
			
		||||
@ -114,8 +119,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
 | 
			
		||||
                       "Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
 | 
			
		||||
                       "Please use float32 instead.")
 | 
			
		||||
    case ScalarType::ComplexHalf:
 | 
			
		||||
      checkSupportsComplex();
 | 
			
		||||
      return MPSDataTypeComplexFloat16;
 | 
			
		||||
    case ScalarType::ComplexFloat:
 | 
			
		||||
      checkSupportsComplex();
 | 
			
		||||
      return MPSDataTypeComplexFloat32;
 | 
			
		||||
    // Unsigned types
 | 
			
		||||
    case ScalarType::UInt64:
 | 
			
		||||
@ -133,10 +140,16 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
 | 
			
		||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
 | 
			
		||||
// Int32, Half and Float32 types. These utilities are to help cast to these
 | 
			
		||||
// types.
 | 
			
		||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
 | 
			
		||||
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
 | 
			
		||||
                               MPSGraphTensor* inputTensor,
 | 
			
		||||
                               const TensorBase& input,
 | 
			
		||||
                               bool includesInt64) {
 | 
			
		||||
  MPSDataType dataType = getMPSDataType(input.scalar_type());
 | 
			
		||||
  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
 | 
			
		||||
      (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
 | 
			
		||||
  bool condition =
 | 
			
		||||
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
 | 
			
		||||
  if (includesInt64) {
 | 
			
		||||
    condition = condition && (dataType != MPSDataTypeInt64);
 | 
			
		||||
  }
 | 
			
		||||
  if (condition) {
 | 
			
		||||
    dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
 | 
			
		||||
    return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
 | 
			
		||||
@ -147,10 +160,16 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor,
 | 
			
		||||
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
 | 
			
		||||
// Int32, Half and Float32 types. These utilities are to help cast from these
 | 
			
		||||
// types.
 | 
			
		||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
 | 
			
		||||
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
 | 
			
		||||
                                 MPSGraphTensor* inputTensor,
 | 
			
		||||
                                 const TensorBase& input,
 | 
			
		||||
                                 bool includesInt64) {
 | 
			
		||||
  MPSDataType dataType = getMPSDataType(input.scalar_type());
 | 
			
		||||
  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
 | 
			
		||||
      (dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
 | 
			
		||||
  bool condition =
 | 
			
		||||
      (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
 | 
			
		||||
  if (includesInt64) {
 | 
			
		||||
    condition = condition && (dataType != MPSDataTypeInt64);
 | 
			
		||||
  }
 | 
			
		||||
  if (condition) {
 | 
			
		||||
    inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
 | 
			
		||||
  }
 | 
			
		||||
@ -167,6 +186,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
 | 
			
		||||
    case ScalarType::Half:
 | 
			
		||||
      return MPSDataTypeFloat16;
 | 
			
		||||
    case ScalarType::BFloat16:
 | 
			
		||||
      checkSupportsBFloat16();
 | 
			
		||||
      return MPSDataTypeBFloat16;
 | 
			
		||||
    case ScalarType::Int:
 | 
			
		||||
      return MPSDataTypeInt32;
 | 
			
		||||
@ -181,11 +201,13 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
 | 
			
		||||
    case ScalarType::Bool:
 | 
			
		||||
      return MPSDataTypeBool;
 | 
			
		||||
    case ScalarType::ComplexHalf:
 | 
			
		||||
      checkSupportsComplex();
 | 
			
		||||
      return MPSDataTypeComplexFloat16;
 | 
			
		||||
    // This is an intentional fallthrough supporting ComplexDouble for Scalar
 | 
			
		||||
    // types as they are casted to Complex64 currently.
 | 
			
		||||
    case ScalarType::ComplexDouble:
 | 
			
		||||
    case ScalarType::ComplexFloat:
 | 
			
		||||
      checkSupportsComplex();
 | 
			
		||||
      return MPSDataTypeComplexFloat32;
 | 
			
		||||
    // Unsigned types
 | 
			
		||||
    case ScalarType::UInt64:
 | 
			
		||||
@ -245,6 +267,7 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
 | 
			
		||||
    case ScalarType::Half:
 | 
			
		||||
      return "half";
 | 
			
		||||
    case ScalarType::BFloat16:
 | 
			
		||||
      checkSupportsBFloat16();
 | 
			
		||||
      return "bfloat";
 | 
			
		||||
    case ScalarType::Int:
 | 
			
		||||
      return "int";
 | 
			
		||||
@ -856,7 +879,9 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
 | 
			
		||||
  MTLCompileOptions* options = compile_options;
 | 
			
		||||
  if (!options) {
 | 
			
		||||
    options = [[MTLCompileOptions new] autorelease];
 | 
			
		||||
    [options setLanguageVersion:MTLLanguageVersion3_1];
 | 
			
		||||
    // Need 3.0 for atomic oprations, 3.1 introduces bfloat support
 | 
			
		||||
    [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
 | 
			
		||||
                                                                                        : MTLLanguageVersion3_0];
 | 
			
		||||
    if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
 | 
			
		||||
      options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe;
 | 
			
		||||
      options.mathFloatingPointFunctions =
 | 
			
		||||
 | 
			
		||||
@@ -5,6 +5,29 @@
using namespace metal;
using namespace c10::metal;

namespace c10 {
namespace metal {
// There are no atomic 64-bit add in Metal yet, but this implements a consistent
// add I.e. if multiple threads are modify the same 64-bit value, results stored
// at the address will eventually be equal to its original value plus sum of all
// operands
template <>
struct AtomicType<long> {
  using type = ::metal::atomic<uint>;
  static inline void atomic_add(device type* data, long offset, long value) {
    const auto value_bits = as_type<ulong>(value);
    const uint low = static_cast<uint>(value_bits);
    uint high = static_cast<uint>(value_bits >> 32);
    auto ptr = data + (offset << 1);
    auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
    high += (old_low + low < old_low) ? 1 : 0;
    atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
  }
};

} // namespace metal
} // namespace c10

struct IndexAB {
  constant int64_t* indexArray;
};
@@ -211,15 +234,13 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);

REGISTER_INDEX_OP(put_accumulate, float, float);
REGISTER_INDEX_OP(put_accumulate, half, half);
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
REGISTER_INDEX_OP(put_accumulate, long, long);
REGISTER_INDEX_OP(put_accumulate, int, int);
REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
REGISTER_INDEX_OP(put_accumulate, float2, float2);
REGISTER_INDEX_OP(put_accumulate, half2, half2);
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);

template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(

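Aside (not part of the diff): the AtomicType<long> specialization above emulates a 64-bit atomic add with two 32-bit atomic adds plus a manually propagated carry; as the comment in the hunk says, the stored value is only guaranteed to equal the original plus the sum of all operands once every writer has finished. A host-side C++ sketch of the same trick, for illustration only:

#include <atomic>
#include <cstdint>

// Illustrative restatement of the split 64-bit add: the low 32 bits are added
// first, and a wrap-around of the low word owes exactly one carry to the high
// word. Intermediate reads can observe a torn pair; only the final sum is
// meaningful, as with the Metal kernel above.
struct SplitU64 {
  std::atomic<uint32_t> low{0};
  std::atomic<uint32_t> high{0};

  void add(int64_t value) {
    const uint64_t bits = static_cast<uint64_t>(value);
    const uint32_t lo = static_cast<uint32_t>(bits);
    uint32_t hi = static_cast<uint32_t>(bits >> 32);
    // fetch_add returns the previous low word; detect a wrap to add the carry.
    const uint32_t old_low = low.fetch_add(lo, std::memory_order_relaxed);
    hi += (static_cast<uint32_t>(old_low + lo) < old_low) ? 1u : 0u;
    high.fetch_add(hi, std::memory_order_relaxed);
  }

  uint64_t read_after_all_writers_done() const {
    return (static_cast<uint64_t>(high.load()) << 32) | low.load();
  }
};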
@@ -68,37 +68,6 @@ kernel void matmul(
  }
}

template <typename T>
kernel void addmm(
    constant T* mat1Data [[buffer(0)]],
    constant T* mat2Data [[buffer(1)]],
    device T* outputData [[buffer(2)]],
    constant T* biasData [[buffer(3)]],
    constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4)]],
    constant array<ulong2, 4>& strides [[buffer(5)]],
    constant uint3& sizes [[buffer(6)]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint2 thread_id [[thread_position_in_grid]]) {
  threadgroup T A_tile[TILE_DIM][TILE_DIM];
  threadgroup T B_tile[TILE_DIM][TILE_DIM];

  auto sum = matmul_inner<T>(
      mat1Data,
      mat2Data,
      reinterpret_cast<constant array<ulong2, 3>&>(strides),
      sizes,
      A_tile,
      B_tile,
      tid,
      thread_id);
  if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
    auto bias =
        biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
    outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
        static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
  }
}

template <typename T>
kernel void naive_bmm(
    constant T* mat1Data [[buffer(0)]],
@@ -644,15 +613,17 @@ kernel void applyPivots(
  }
}

#define INSTANTIATE_MM_OPS(DTYPE)                                           \
  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>(       \
      constant DTYPE * mat1Data [[buffer(0)]],                              \
      constant DTYPE * mat2Data [[buffer(1)]],                              \
      device DTYPE * outputData [[buffer(2)]],                              \
      constant array<ulong2, 3> & strides [[buffer(3)]],                    \
      constant uint3 & sizes [[buffer(4)]],                                 \
      uint2 tid [[thread_position_in_threadgroup]],                         \
      uint2 group_id [[threadgroup_position_in_grid]]);                     \
#define INSTANTIATE_NAIVE_MM(DTYPE)                                   \
  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
      constant DTYPE * mat1Data [[buffer(0)]],                        \
      constant DTYPE * mat2Data [[buffer(1)]],                        \
      device DTYPE * outputData [[buffer(2)]],                        \
      constant array<ulong2, 3> & strides [[buffer(3)]],              \
      constant uint3 & sizes [[buffer(4)]],                           \
      uint2 tid [[thread_position_in_threadgroup]],                   \
      uint2 group_id [[threadgroup_position_in_grid]])

#define INSTANTIATE_NAIVE_BMM(DTYPE)                                        \
  template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
      constant DTYPE * mat1Data [[buffer(0)]],                              \
      constant DTYPE * mat2Data [[buffer(1)]],                              \
@@ -660,26 +631,20 @@ kernel void applyPivots(
      constant array<ulong, 9> & strides [[buffer(3)]],                     \
      constant uint4 & sizes [[buffer(4)]],                                 \
      uint3 tid [[thread_position_in_threadgroup]],                         \
      uint3 group_id [[threadgroup_position_in_grid]]);                     \
  template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>(         \
      constant DTYPE * mat1Data [[buffer(0)]],                              \
      constant DTYPE * mat2Data [[buffer(1)]],                              \
      device DTYPE * outputData [[buffer(2)]],                              \
      constant DTYPE * biasData [[buffer(3)]],                              \
      constant array<c10::metal::opmath_t<DTYPE>, 2> &                      \
          alpha_beta [[buffer(4)]],                                         \
      constant array<ulong2, 4> & strides [[buffer(5)]],                    \
      constant uint3 & sizes [[buffer(6)]],                                 \
      uint2 tid [[thread_position_in_threadgroup]],                         \
      uint2 group_id [[threadgroup_position_in_grid]])
      uint3 group_id [[threadgroup_position_in_grid]])

INSTANTIATE_MM_OPS(float);
INSTANTIATE_MM_OPS(half);
INSTANTIATE_MM_OPS(bfloat);
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
INSTANTIATE_NAIVE_MM(bfloat);

// Integral MM
INSTANTIATE_MM_OPS(long);
INSTANTIATE_MM_OPS(int);
INSTANTIATE_MM_OPS(short);
INSTANTIATE_MM_OPS(char);
INSTANTIATE_MM_OPS(uchar);
INSTANTIATE_NAIVE_MM(short);
INSTANTIATE_NAIVE_MM(int);
INSTANTIATE_NAIVE_MM(long);
INSTANTIATE_NAIVE_MM(char);
INSTANTIATE_NAIVE_MM(uchar);
INSTANTIATE_NAIVE_BMM(short);
INSTANTIATE_NAIVE_BMM(int);
INSTANTIATE_NAIVE_BMM(long);
INSTANTIATE_NAIVE_BMM(char);
INSTANTIATE_NAIVE_BMM(uchar);

@@ -88,53 +88,6 @@ void max_pool_3d_input_iter(
  }
}

template <typename T, bool return_indices>
void max_pool_2d_input_iter(
    constant T* input,
    device T* output,
    device int64_t* indices,
    constant int32_t* input_sizes,
    constant int32_t* input_strides,
    thread int32_t (&pooling_dim_indices)[3],
    constant int32_t* kernel_size,
    constant int32_t* stride,
    constant int32_t* padding,
    constant int32_t* dilation) {
  auto bounds0 = get_input_iter_bounds<0>(
      input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
  auto bounds1 = get_input_iter_bounds<1>(
      input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);

  auto d0 = dilation[0];
  auto d1 = dilation[1];

  T max_value = input
      [input_strides[0] * bounds0.start + input_strides[1] * bounds1.start];
  auto max_index = bounds0.start * input_sizes[1] + bounds1.start;

  for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) {
    auto offset0 = input_strides[0] * i0;

    for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) {
      auto offset1 = input_strides[1] * i1;

      auto input_value = input[offset0 + offset1];
      bool is_greater = input_value > max_value;

      max_value = is_greater ? input_value : max_value;

      if (return_indices) {
        auto input_index = i0 * input_sizes[1] + i1;
        max_index = is_greater ? input_index : max_index;
      }
    }
  }
  *output = max_value;
  if (return_indices) {
    *indices = max_index;
  }
}

struct PoolOffsets {
  int32_t output;
  int32_t indices;
@@ -259,7 +212,7 @@ kernel void max_pool(
  PoolOffsets offsets = find_pool_offsets(
      output_sizes,
      output_strides,
      return_indices ? indices_strides : nullptr,
      indices_strides,
      input_strides,
      pooling_dim_indices,
      dims,
@@ -271,47 +224,18 @@ kernel void max_pool(
  indices += offsets.indices;
  input += offsets.input_leading;

  switch (pooling_dims) {
    case 2:
      if (return_indices) {
        return max_pool_2d_input_iter<T, /*return_indices=*/true>(
            input,
            output,
            indices,
            input_sizes + leading_dims,
            input_strides + leading_dims,
            pooling_dim_indices,
            kernel_size,
            stride,
            padding,
            dilation);
      } else {
        return max_pool_2d_input_iter<T, /*return_indices=*/false>(
            input,
            output,
            indices,
            input_sizes + leading_dims,
            input_strides + leading_dims,
            pooling_dim_indices,
            kernel_size,
            stride,
            padding,
            dilation);
      }
    case 3:
      return max_pool_3d_input_iter<T>(
          input,
          output,
          indices,
          input_sizes + leading_dims,
          input_strides + leading_dims,
          pooling_dim_indices,
          kernel_size,
          stride,
          padding,
          dilation,
          return_indices);
  }
  max_pool_3d_input_iter<T>(
      input,
      output,
      indices,
      input_sizes + leading_dims,
      input_strides + leading_dims,
      pooling_dim_indices,
      kernel_size,
      stride,
      padding,
      dilation,
      return_indices);
}

// Finds the element in the grad input which corresponds to the index into the

@@ -53,7 +53,6 @@ void binary_op_kernel(const std::string func_name,
                  .add_input(input)
                  .add_input(other)
                  .check_all_same_dtype(false)
                  .promote_inputs_to_common_dtype(true)
                  .build();

  lib.exec_binary_kernel(iter, func_name, alpha);

@@ -48,11 +48,28 @@ typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*,
#define BinaryOpFn(graph, primary, secondary) \
  MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)

static inline Tensor legacy_complex_as_view(const Tensor& t) {
  // Convert non-complex types (and cdouble CPU scalars) to cfloat
  if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) {
    return at::view_as_real(t.to(kMPS, kComplexFloat));
  }
  return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS));
}

static void binaryOpTensor(const Tensor& self,
                           const Tensor& other,
                           const Tensor& output_,
                           std::string op_name,
                           BinaryOpBlock binaryBlock) {
  TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
                (self.scalar_type() == ScalarType::Long ||
                 (other.scalar_type() == ScalarType::Long &&
                  (self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
              "MPS: ",
              op_name,
              " op with int64 input is supported natively starting from macOS 13.2");
  TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(),
                   "Complex types are supported starting from MacOS 14.0+");
  MPSStream* mpsStream = getCurrentMPSStream();

  const bool is_self_scalar = self.dim() == 0;

@@ -51,6 +51,9 @@ inline void dot_check(const Tensor& self, const Tensor& other) {
} // namespace mps

Tensor dot_mps(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long,
              "MPS: dot op doesn't support int64 input on MacOS13")

  using namespace mps;
  using CachedGraph = MPSBinaryCachedGraph;

@@ -124,6 +124,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
                                    IntArrayRef dilation,
                                    int64_t groups,
                                    std::optional<IntArrayRef> input_shape) {
  const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
  const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  Tensor input_t = input_t_;
  bool is3DConv = input_t.dim() == 5;
@@ -131,6 +132,9 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
    input_t = input_t.contiguous();
  }

  TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
              "Conv3D is only supported on MPS for MacOS_13_2 or newer");

  TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");

  using namespace at::native::mps;

@@ -60,6 +60,7 @@ static void copy_cast_mps(at::Tensor& dst,
        outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"];
      }
      if (needs_conj) {
        TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+");
        outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil];
      }

@@ -274,7 +275,24 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
    // for GPU to GPU copies we only encode to stream's command buffer (no flushing)
    stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id);
  } else {
    if (dst_byte_offset) {
    // Simulate cast to Complex on older MacOS by initializing real and imag parts
    if (dst_.is_complex() && !supportsComplex()) {
      if (!src.is_complex()) {
        at::real(dst_).copy_(src);
        at::imag(dst_).fill_(0);
      } else if (src.is_conj() || dst_.is_conj()) {
        // One cannot take view of conjugated tensor, but for some reason real and imag views are fine
        // Use this to implement a conjugation
        at::real(dst_).copy_(at::real(src));
        if (src.is_conj() != dst_.is_conj()) {
          at::imag(dst_).copy_(at::neg(at::imag(src)));
        } else {
          at::imag(dst_).copy_(at::imag(src));
        }
      } else {
        at::view_as_real(dst_).copy_(at::view_as_real(src));
      }
    } else if (dst_byte_offset) {
      auto maybeCastedSource =
          at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
      auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);

@@ -87,6 +87,7 @@ Tensor& random_mps_impl(Tensor& self,
          case kFloat:
            return MPSDataTypeFloat32;
          case kBFloat16: {
            checkSupportsBFloat16();
            return MPSDataTypeBFloat16;
          }
          default:

@@ -88,6 +88,7 @@ using namespace mps;

// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(onesided);
  @autoreleasepool {
@@ -128,6 +129,7 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
                         int64_t normalization,
                         int64_t last_dim_size,
                         Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(last_dim_size);
  @autoreleasepool {
@@ -153,6 +155,7 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
}

Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
  TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
  auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
      std::to_string(normalization) + ":" + std::to_string(forward);
  @autoreleasepool {

@@ -127,6 +127,15 @@ Tensor grid_sampler_2d_mps(const Tensor& input,
                           int64_t interpolation_mode,
                           int64_t padding_mode,
                           bool align_corners) {
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) {
    TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ",
                    "Falling back on CPU. This may have performance implications.");

    return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
        .clone()
        .to("mps");
  }

  auto in_size = input.sizes();
  auto grid_size = grid.sizes();
  auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());

@@ -108,12 +108,26 @@ static std::string getBitSizeString(const TensorBase& t) {
static void validateInputData(const TensorIteratorBase& iter,
                              IntArrayRef index_size,
                              IntArrayRef index_stride,
                              const std::string& op) {
                              const std::string& op,
                              bool accumulate) {
  using namespace mps;

  const auto num_indices = index_size.size();
  TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels");

  AT_ASSERT(num_indices == index_stride.size());
  AT_ASSERT(static_cast<int>(num_indices) == iter.ntensors() - 2);
  const Tensor& inputTensor = iter.tensor(1);
  const auto scalar_type = inputTensor.scalar_type();

  if (accumulate) {
    // No atomic support for the complex dtypes
    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
  } else {
    TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
                    scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
                getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out"));
  }
}

static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) {
@@ -144,7 +158,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
                                  IntArrayRef index_stride,
                                  const std::string& kernel_name,
                                  const bool serial = false) {
  validateInputData(iter, index_size, index_stride, "index.Tensor_out");
  validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
  if (iter.numel() == 0)
    return;
  if (!iter.can_use_32bit_indexing()) {
@@ -186,7 +200,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
}

static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
  validateInputData(iter, index_size, index_stride, "index.Tensor_out");
  validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
  dispatch_index_kernel(
      iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0))));
}
@@ -196,7 +210,7 @@ static void index_put_kernel_mps(TensorIterator& iter,
                                 IntArrayRef index_stride,
                                 bool accumulate) {
  @autoreleasepool {
    validateInputData(iter, index_size, index_stride, "index_put_impl");
    validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate);
    if (accumulate) {
      dispatch_index_kernel(iter,
                            index_size,
@@ -339,7 +353,14 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
}

Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
  if (self.is_complex()) {
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
                    "Falling back on CPU. This may have performance implications.");
    Tensor out_fallback = nonzero_fallback(self);
    at::native::resize_output(out_, out_fallback.sizes());
    out_.copy_(out_fallback);
    return out_;
  } else if (self.is_complex()) {
    TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
                    "Falling back on CPU. This may have performance implications.");
    Tensor out_fallback = nonzero_fallback(self);
@@ -424,7 +445,11 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
}

Tensor nonzero_mps(const Tensor& self) {
  if (self.is_complex()) {
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
                    "Falling back on CPU. This may have performance implications.");
    return nonzero_fallback(self);
  } else if (self.is_complex()) {
    TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
                    "Falling back on CPU. This may have performance implications.");
    return nonzero_fallback(self);

@@ -112,61 +112,6 @@ Tensor& do_metal_bmm(const Tensor& batch1, const Tensor& batch2, Tensor& output)
  return output;
}

Tensor& do_metal_addmm(const Tensor& self,
                       const Tensor& other,
                       Tensor& output,
                       const Scalar& alpha,
                       const Scalar& beta,
                       const Tensor& bias) {
  if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
    return do_metal_mm(self, other, output);
  }
  auto stream = getCurrentMPSStream();
  auto device = MPSDevice::getInstance()->device();
  auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output));
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other});
      auto computeEncoder = stream->commandEncoder();
      [computeEncoder setComputePipelineState:matmulPSO];
      std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
                                       static_cast<uint32_t>(self.size(1)),
                                       static_cast<uint32_t>(output.size(1))};
      std::array<int64_t, 8> strides = {self.stride(0),
                                        self.stride(1),
                                        other.stride(0),
                                        other.stride(1),
                                        output.stride(0),
                                        output.stride(1),
                                        bias.stride(0),
                                        bias.stride(1)};
      union {
        std::array<int64_t, 2> i64;
        std::array<int32_t, 2> i32;
        std::array<float, 2> f32;
      } alpha_beta;
      if (output.scalar_type() == kLong) {
        alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
      } else if (c10::isIntegralType(output.scalar_type(), true)) {
        alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
      } else {
        TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
        alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
      }
      constexpr uint32_t TILE_DIM = 16; // fastest performance from tests on multiple macs
      uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
      uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;

      MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
      MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
      mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes);
      [computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
      getMPSProfiler().endProfileKernel(matmulPSO);
    }
  });
  return output;
}

std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
                                                                    const Tensor& self,
                                                                    const Tensor& other) {
@@ -699,6 +644,7 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,

  TORCH_CHECK(output.is_mps());
  TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
  TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input");

  TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
  checkAllSameGPU(__func__, args);
@@ -725,10 +671,6 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
    return output;
  }

  if (use_metal_mm(self, other, output)) {
    return do_metal_addmm(self, other, output, alpha, beta, *bias_);
  }

  bool is_beta_non_zero = beta.toDouble() != 0.0;

  struct CachedGraph : public mps::MPSCachedGraph {

@@ -297,13 +297,13 @@ static PoolSizes process_pool_sizes(const Tensor& input,
              pooling_dims,
              " ints");

  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == pooling_dims,
  TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
              op_name,
              ": stride must either be omitted, a single int, or a tuple of ",
              pooling_dims,
              " ints");

  TORCH_CHECK(padding.size() == 1 || padding.size() == pooling_dims,
  TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
              op_name,
              ": padding must either be a single int, or a tuple of ",
              pooling_dims,
@@ -333,22 +333,6 @@ static PoolSizes process_pool_sizes(const Tensor& input,
                ": pad should be at most half of effective kernel size");
  }

  if (pooling_dims == 2) {
    const auto memory_format = input.suggest_memory_format();
    bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
    if (memory_format == at::MemoryFormat::ChannelsLast) {
      // Expect tensor in NHWC format and allow 0-dim only for N.
      TORCH_CHECK((dims == 4 && valid_dims && input.size(3) != 0),
                  "Expected 4D (batch mode) tensor expected for input with channels_last layout"
                  " with optional 0 dim batch size for input, but got: ",
                  input.sizes());
    } else {
      TORCH_CHECK((dims == 3 && input.size(0) != 0 && valid_dims) || (dims == 4 && valid_dims && input.size(3) != 0),
                  "Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
                  input.sizes());
    }
  }

  for (const auto dim : c10::irange(static_cast<int>(leading_dims == 2), dims)) {
    TORCH_CHECK(input.size(dim) > 0, op_name, ": Expected input's non-batch dimensions to have positive length");
  }
@@ -802,16 +786,6 @@ static void avg_pool_backward_out_mps_template(const Tensor& grad_input,

} // namespace mps

// TODO: The MPS graph impl can sometimes give significantly better performance
// than the Metal impl for cases where the stride is 1 in all dimensions. There
// may be a code path in the graph kernel that specifically optimizes for that
// case. We should look into implementing a specialized case in Metal so we can
// avoid using the graph impl.
static bool use_graph_for_max_pool2d(IntArrayRef kernel_size, IntArrayRef stride_) {
  IntArrayRef stride = stride_.empty() ? kernel_size : stride_;
  return (stride[0] == 1) && (stride.size() == 1 || stride[1] == 1);
}

Tensor mps_max_pool2d(const Tensor& input,
                      IntArrayRef kernel_size,
                      IntArrayRef stride,
@@ -819,37 +793,24 @@ Tensor mps_max_pool2d(const Tensor& input,
                      IntArrayRef dilation,
                      bool ceil_mode) {
  Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
  bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
  if (use_graph) {
    mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
      MPSGraph* mpsGraph = cachedGraph.graph();
      return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
    };
    mps::pool2d_template(input,
                         output,
                         std::nullopt,
                         std::nullopt,
                         kernel_size,
                         stride,
                         padding,
                         dilation,
                         ceil_mode,
                         false,
                         std::nullopt,
                         pooling_op_block,
                         "max_pool2d");
  } else {
    mps::max_pool_with_indices_out_mps_template(output,
                                                std::nullopt,
                                                input,
                                                kernel_size,
                                                stride,
                                                padding,
                                                dilation,
                                                ceil_mode,
                                                /*pooling_dims=*/2,
                                                "max_pool2d");
  }
  mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
    MPSGraph* mpsGraph = cachedGraph.graph();
    return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
  };
  mps::pool2d_template(input,
                       output,
                       std::nullopt,
                       std::nullopt,
                       kernel_size,
                       stride,
                       padding,
                       dilation,
                       ceil_mode,
                       false,
                       std::nullopt,
                       pooling_op_block,
                       "max_pool2d");

  return output;
}

@@ -894,45 +855,32 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
 bool ceil_mode,
 const Tensor& output,
 const Tensor& indices) {
  bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
  if (use_graph) {
    auto indices_memory_format = indices.suggest_memory_format();
  auto indices_memory_format = indices.suggest_memory_format();

    mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
      MPSGraph* mpsGraph = cachedGraph.graph();
      NSArray<MPSGraphTensor*>* poolOutputs =
          [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
      cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
      return poolOutputs[0];
    };
    mps::pool2d_template(input,
                         output,
                         indices,
                         std::nullopt,
                         kernel_size,
                         stride,
                         padding,
                         dilation,
                         ceil_mode,
                         false,
                         std::nullopt,
                         pooling_op_block,
                         "max_pool2d_indices");
    if (indices_memory_format == MemoryFormat::ChannelsLast) {
      const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
    }
  mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
    MPSGraph* mpsGraph = cachedGraph.graph();
    NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor
                                                                                     descriptor:desc
                                                                                           name:nil];
    cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
    return poolOutputs[0];
  };
  mps::pool2d_template(input,
                       output,
                       indices,
                       std::nullopt,
                       kernel_size,
                       stride,
                       padding,
                       dilation,
                       ceil_mode,
                       false,
                       std::nullopt,
                       pooling_op_block,
                       "max_pool2d_indices");

  } else {
    mps::max_pool_with_indices_out_mps_template(output,
                                                indices,
                                                input,
                                                kernel_size,
                                                stride,
                                                padding,
                                                dilation,
                                                ceil_mode,
                                                /*pooling_dims=*/2,
                                                "max_pool2d");
  if (indices_memory_format == MemoryFormat::ChannelsLast) {
    const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
  }
}

@@ -152,6 +152,8 @@ static void reduction_out_mps(const Tensor& input_t,
                              const Tensor& output_t,
                              MPSReductionType reduction_type,
                              const std::string& func_name) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
  // NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor
  bool canSqueezeLastDim = true;
  IntArrayRef input_shape = input_t.sizes();
@@ -234,10 +236,12 @@ static void reduction_out_mps(const Tensor& input_t,
      MPSGraphTensor* castInputTensor = inputTensor;
      MPSDataType inputCastType = MPSDataTypeInvalid;
      if (dtype.has_value() &&
          (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) {
          (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
           (dtype.value() == kLong && macOS13_3_plus))) {
        inputCastType = getMPSDataType(dtype.value());
      } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
                 inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) {
                 inputScalarType != kComplexFloat && inputScalarType != kComplexHalf &&
                 (inputScalarType != kLong || !macOS13_3_plus)) {
        inputCastType = getMPSDataType(kFloat);
      }

@@ -611,6 +615,9 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
}

static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median");

  IntArrayRef input_shape = input_t.sizes();
  int64_t num_in_elements = c10::multiply_integers(input_shape);

@@ -627,7 +634,8 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
  auto medianCachedGraph =
      LookUpOrCreateCachedGraph<MedianCachedGraph>(medianKey, [&](auto mpsGraph, auto newCachedGraph) {
        MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
        MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
        MPSGraphTensor* castInputTensor =
            castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);

        MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];

@@ -685,6 +693,9 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
}

static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");

  using CachedGraph = MPSUnaryCachedGraph;

  IntArrayRef input_shape = input_t.sizes();
@@ -702,7 +713,8 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

      MPSGraphTensor* castOutputTensor = nil;
      MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      MPSGraphTensor* castInputTensor =
          castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);

      NSArray<NSNumber*>* axes = getTensorAxes(input_t);
      if (reduction_type == MPSReductionType::MAX) {
@@ -737,6 +749,9 @@ static void min_max_out_mps(const Tensor& input_t,
                            const Tensor& indices_t,
                            MPSReductionType reduction_type,
                            const std::string& func_name) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");

  if (output_t.numel() == 0) {
    return;
  }
@@ -774,7 +789,8 @@ static void min_max_out_mps(const Tensor& input_t,
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
      MPSGraphTensor* outputTensor = nil;
      MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      MPSGraphTensor* castInputTensor =
          castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);

      if (reduction_type == MPSReductionType::MAX) {
        outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
@@ -880,6 +896,9 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
                                  const std::string& func_name) {
  using CachedGraph = MPSUnaryCachedGraph;

  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");

  int64_t dim_ = -1;

  if (dim.has_value()) {
@@ -934,7 +953,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,

      MPSGraphTensor* castInputTensor = inputTensor;
      if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
          inputScalarType != kLong) {
          (inputScalarType != kLong || !macOS13_3_plus)) {
        castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
      }
      if (reduction_type == MPSReductionType::MAX) {
@@ -1263,6 +1282,9 @@ static void all_any_common_impl_mps(const Tensor& input_t,
    return;
  }

  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name);

  int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
  native::zero_numel_check_dims(input_t, dim_, op_name.c_str());

@@ -1281,7 +1303,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
      // reductionOrWithTensor:axis: will throw an internal assert if number of dimentions is more than 4
      // See https://github.com/pytorch/pytorch/issues/95538
      MPSGraphTensor* outputTensor = nil;
@@ -1347,11 +1369,14 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
    return;
  }

  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");

  @autoreleasepool {
    std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
      // reductionOrWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
      // See https://github.com/pytorch/pytorch/issues/95538
      if (input_t.dim() > 4) {
@@ -1395,11 +1420,14 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
    return;
  }

  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");

  @autoreleasepool {
    std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
      // reductionAndWithTensor:axes: will throw an internal assert if number of dimentions is more than 4
      // See https://github.com/pytorch/pytorch/issues/95538
      if (input_t.ndimension() > 4) {
@@ -1484,6 +1512,9 @@ static void median_out_mps_common(const Tensor& input_t,
                                  Tensor& indices,
                                  const std::string& func_name,
                                  bool nanmedian) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");

  int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
  native::zero_numel_check_dims(input_t, dim_, "max()");

@@ -1554,7 +1585,8 @@ static void median_out_mps_common(const Tensor& input_t,
        getTensorsStringKey(indices);
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
      MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
      MPSGraphTensor* castInputTensor =
          castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);

      MPSGraphTensor* effectiveLengthTensor = nil;
      if (nanmedian) {

@@ -129,8 +129,16 @@ void computeRepeatIndices(const index_t* repeat_ptr,
  });
}

Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional<int64_t> output_size) {
  Tensor output;
  Tensor repeat = repeat_;
  if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
    // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
    // which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
    TORCH_WARN_ONCE(
        "MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
    repeat = repeat.to(kInt);
  }
  AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
    output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
  });

@ -23,6 +23,125 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/ScanKernel_metallib.h>
#endif

// Generic scan implementation that handles both simple scans and scans with indices
static void scan_mps_impl(const Tensor& self,
                          const std::vector<Tensor>& outputs,
                          int64_t dim,
                          const std::string& op_name) {
  if (outputs[0].numel() == 0) {
    return;
  }

  const int64_t ndim = self.dim();
  const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim);

  // Calculate dimensions for scan operation
  int64_t row_size = self.size(wrapped_dim);
  auto sizes = self.sizes();

  bool is_innermost = (wrapped_dim == ndim - 1);

  // Check if all tensors are contiguous
  bool is_contiguous = self.is_contiguous();
  for (const auto& output : outputs) {
    is_contiguous = is_contiguous && output.is_contiguous();
  }

  uint32_t num_rows, num_orows, num_irows, num_threads;

  if (is_innermost) {
    // Treat all outer dimensions as a single dimension
    num_rows = self.numel() / row_size;
    num_threads = num_rows;
  } else {
    // Treat all outer dimensions (i.e. dim_ < dim) as one
    num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim);
    // Treat all inner dimensions (i.e. dim > dimension) as one
    num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end());
    num_threads = num_orows * num_irows;
  }

  MPSStream* mpsStream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();

      // Choose kernel based on contiguity and dimension
      std::string kernel_name;
      if (is_contiguous) {
        kernel_name =
            op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self);
      } else {
        kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self);
      }

      id<MTLComputePipelineState> scanPSO = lib.getPipelineStateForFunc(kernel_name);

      // this function call is a no-op if MPS Profiler is not enabled
      getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() {
        std::vector<Tensor> all_tensors = {self};
        all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end());
        return all_tensors;
      }());

      [computeEncoder setComputePipelineState:scanPSO];

      // Set input tensor
      mtl_setBuffer(computeEncoder, self, 0);

      // Set output tensors
      for (size_t i = 0; i < outputs.size(); ++i) {
        mtl_setBuffer(computeEncoder, outputs[i], i + 1);
      }

      if (is_contiguous) {
        // Contiguous kernels
        if (is_innermost) {
          if (outputs.size() == 1) {
            // Simple scan
            mtl_setArgs<2>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
          } else {
            // Scan with indices
            mtl_setArgs<3>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
          }
        } else {
          if (outputs.size() == 1) {
            // Simple scan
            mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
          } else {
            // Scan with indices
            mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
          }
        }
      } else {
        // Strided kernels - pass full tensor information
        if (outputs.size() == 1) {
          // Simple scan
          mtl_setArgs<2>(computeEncoder,
                         self.sizes(),
                         self.strides(),
                         outputs[0].strides(),
                         static_cast<uint32_t>(self.ndimension()),
                         static_cast<uint32_t>(wrapped_dim));
        } else {
          // Scan with indices
          mtl_setArgs<3>(computeEncoder,
                         self.sizes(),
                         self.strides(),
                         outputs[0].strides(),
                         outputs[1].strides(),
                         static_cast<uint32_t>(self.ndimension()),
                         static_cast<uint32_t>(wrapped_dim));
        }
      }

      mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads);

      getMPSProfiler().endProfileKernel(scanPSO);
    }
  });
}
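For orientation, here is a minimal NumPy sketch of the row decomposition used by scan_mps_impl above (illustrative only, not the Metal kernel): a scan along wrapped_dim views the tensor as num_orows x row_size x num_irows and walks one row per thread.

import numpy as np

def scan_rows(x: np.ndarray, dim: int, op=np.add) -> np.ndarray:
    # Mirror the num_orows / row_size / num_irows split computed in scan_mps_impl.
    sizes = x.shape
    row_size = sizes[dim]
    num_orows = int(np.prod(sizes[:dim]))
    num_irows = int(np.prod(sizes[dim + 1:]))
    v = x.reshape(num_orows, row_size, num_irows)
    out = np.empty_like(v)
    for orow in range(num_orows):          # one "thread" per (orow, irow) pair
        for irow in range(num_irows):
            acc = v[orow, 0, irow]
            out[orow, 0, irow] = acc
            for k in range(1, row_size):   # sequential scan along the row
                acc = op(acc, v[orow, k, irow])
                out[orow, k, irow] = acc
    return out.reshape(sizes)

x = np.arange(24.0).reshape(2, 3, 4)
assert np.allclose(scan_rows(x, 1), np.cumsum(x, axis=1))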

// Utility function to get 2D grid dimensions for dispatch
static std::pair<uint32_t, uint32_t> get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) {
  size_t grid_x = 1;
@ -256,11 +375,19 @@ static void scan_with_indices_mps_impl(const Tensor& self,
} // namespace mps

void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
  mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
    mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
  } else {
    mps::scan_mps_impl(self, {values, indices}, dim, "cummax");
  }
}

void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
  mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
    mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
  } else {
    mps::scan_mps_impl(self, {values, indices}, dim, "cummin");
  }
}

Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
@ -275,7 +402,11 @@ Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
    return result;
  }

  mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
  if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
    mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
  } else {
    mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp");
  }
  return result;
}
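For reference, the quantity logcumsumexp accumulates along a dimension is a running log-add-exp; a small NumPy sketch of the math (not the MPS kernel):

import numpy as np

def logcumsumexp_1d(x: np.ndarray) -> np.ndarray:
    # out[i] = log(sum_{j <= i} exp(x[j])), accumulated stably via logaddexp
    out = np.empty(x.shape, dtype=np.float64)
    acc = -np.inf
    for i, v in enumerate(x):
        acc = np.logaddexp(acc, v)
        out[i] = acc
    return out

x = np.array([0.1, -2.0, 3.5, 1.0])
assert np.allclose(logcumsumexp_1d(x), np.log(np.cumsum(np.exp(x))))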


@ -26,6 +26,9 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
 const Tensor& indices) {
  using namespace mps;

  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");

  if (self.numel() == 0) {
    return;
  }
@ -52,7 +55,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);

      MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
      MPSGraphTensor* castInputTensor =
          castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
      MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
                                                         axis:(NSInteger)dim
                                                   descending:(BOOL)descending

@ -297,6 +297,9 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,

  const auto common_type = at::result_type(elements, test_elements);
  TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
  TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type),
              "isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
              common_type);

  @autoreleasepool {
    std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);

@ -208,12 +208,28 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
}

Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
  mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
    auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
    auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
    return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
  });
  return output;
  if (mps::supportsComplex()) {
    mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
      auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
      auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
      return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
    });
    return output;
  } else {
    TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13")
    mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
      // On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are
      // not available, and NaN is not propagated correctly:
      auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType];
      auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil];
      auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil];
      return [mpsGraph selectWithPredicateTensor:nanMask
                             truePredicateTensor:inputTensor
                            falsePredicateTensor:result
                                            name:nil];
    });
    return output;
  }
}
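The two branches above reduce to simple math; a NumPy sketch (illustrative only): on the complex path the angle is atan2(imag, real), and on the real-only fallback it is atan2(0, x) with NaN inputs propagated explicitly.

import numpy as np

def angle_complex(z: np.ndarray) -> np.ndarray:
    # supportsComplex() branch: atan2(imag, real)
    return np.arctan2(z.imag, z.real)

def angle_real_fallback(x: np.ndarray) -> np.ndarray:
    # pre-macOS 14 fallback: atan2(0, x) is 0 for x >= 0 and pi for x < 0;
    # NaN inputs are passed through, mirroring the select on the NaN mask.
    result = np.arctan2(np.zeros_like(x), x)
    return np.where(np.isnan(x), x, result)

print(angle_complex(np.array([1 + 1j, -1 + 0j])))          # [0.785..., 3.141...]
print(angle_real_fallback(np.array([2.0, -3.0, np.nan])))  # [0., 3.141..., nan]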

Tensor angle_mps(const Tensor& self) {
@ -346,6 +362,7 @@ static void cumulative_op_impl(const Tensor& self,
                               const Tensor& result,
                               MPSCumulativeOpType cumulativeOpType,
                               const std::string& op_name) {
  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
  auto nDims = self.dim();
  auto wrapped_dim = maybe_wrap_dim(dim, nDims);
  TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
@ -364,6 +381,11 @@ static void cumulative_op_impl(const Tensor& self,
  bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int &&
                        input.scalar_type() != ScalarType::Long);

  TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
              "MPS does not support ",
              op_name,
              " op with int64 input. Support has been added in macOS 13.3");

  mps::unary_op(
      input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
        if (castInputData) {
@ -418,9 +440,17 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {

Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
  TORCH_CHECK(self.is_complex());
  mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
    return [mpsGraph conjugateWithTensor:inputTensor name:nil];
  });
  if (!mps::supportsComplex()) {
    if (!result.is_same_size(self)) {
      result.resize_(self.sizes());
    }
    at::real(result).copy_(at::real(self));
    at::imag(result).copy_(at::neg(at::imag(self)));
  } else {
    mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
      return [mpsGraph conjugateWithTensor:inputTensor name:nil];
    });
  }
  return result;
}
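A short PyTorch sketch of the fallback branch above, run on CPU purely for illustration: when complex graph support is unavailable, the physical conjugate is materialised by copying the real part and negating the imaginary part.

import torch

def conj_physical_fallback(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the !supportsComplex() path: out.real = x.real, out.imag = -x.imag
    out = torch.empty_like(x)
    out.real.copy_(x.real)
    out.imag.copy_(x.imag.neg())
    return out

x = torch.tensor([1 + 2j, -3 + 4j], dtype=torch.complex64)
assert torch.equal(conj_physical_fallback(x), x.conj_physical())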


@ -7423,7 +7423,6 @@
  dispatch:
    SparseCPU: _coalesce_sparse_cpu
    SparseCUDA: _coalesce_sparse_cuda
    SparseMPS: _coalesce_sparse_mps
  autogen: _coalesce.out

- func: is_coalesced(Tensor self) -> bool
@ -15013,7 +15012,6 @@
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
    NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@ -15046,11 +15044,6 @@
    CUDA: _cudnn_attention_forward
  tags: nondeterministic_seeded

- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  dispatch:
    CUDA: _cudnn_attention_backward
  tags: nondeterministic_seeded

- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
  variants: function
  dispatch:

@ -349,63 +349,6 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
  return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}

std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    const Tensor& philox_seed,
    const Tensor& philox_offset,
    const Tensor& attn_bias,
    const Tensor& cum_seq_q,
    const Tensor& cum_seq_k,
    const int64_t max_q,
    const int64_t max_k,
    double dropout_p,
    bool is_causal,
    std::optional<double> scale) {
  if (!grad_out.defined()) {
    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
  }
  auto [
      grad_out_buffer_reshaped,
      query_buffer_reshaped,
      key_buffer_reshaped,
      value_buffer_reshaped,
      output_buffer_reshaped] =
      preprocessing::sdpa_nested_preprocessing_backward(
          grad_out,
          query,
          key,
          value,
          out,
          cum_seq_q,
          cum_seq_k,
          max_q,
          max_k);

  auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
                                                    query_buffer_reshaped,
                                                    key_buffer_reshaped,
                                                    value_buffer_reshaped,
                                                    output_buffer_reshaped,
                                                    logsumexp,
                                                    philox_seed,
                                                    philox_offset,
                                                    attn_bias,
                                                    cum_seq_q,
                                                    cum_seq_k,
                                                    max_q,
                                                    max_k,
                                                    dropout_p,
                                                    is_causal,
                                                    scale);
  return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}


std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
    const at::Tensor& grad_out_,
    const at::Tensor& query,

@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
      weight.scalar_type() == at::ScalarType::Float ||
          weight.scalar_type() == at::ScalarType::Half,
      "'embedding_bag_byte_prepack' only support float32 or float16.");
  const auto weight_sizes = weight.sym_sizes();
  const auto cols_dim = weight.ndimension() - 1;
  const auto embedding_cols = weight_sizes[cols_dim];
  const auto weight_sizes = weight.sizes();
  const auto cols_dim = weight_sizes.size() - 1;
  const int32_t embedding_cols = static_cast<int32_t>(weight_sizes[cols_dim]);
  // Add 8 bytes per column to store FP32 scale and zero_point per row.
  const auto output_columns = embedding_cols + 2 * sizeof(float);
  const int32_t output_columns = static_cast<int32_t>(embedding_cols + 2 * sizeof(float));

  // Adjust output dimensions to account for FP32 scale and zero_points.
  auto output_shape = weight_sizes.vec();
  std::vector<int64_t> output_shape = weight_sizes.vec();
  output_shape.at(cols_dim) = output_columns;
  at::SymDimVector output_shape_vec(output_shape);
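A small worked example of the shape arithmetic in this hunk (illustrative only): each packed row stores the quantized bytes plus a 4-byte FP32 scale and a 4-byte FP32 zero point, so output_columns = embedding_cols + 8.

def byte_prepack_shape(weight_shape):
    # e.g. a (1000, 64) float weight packs into (1000, 72) uint8 rows
    *leading, embedding_cols = weight_shape
    output_columns = embedding_cols + 2 * 4  # + FP32 scale + FP32 zero point
    return (*leading, output_columns)

assert byte_prepack_shape((1000, 64)) == (1000, 72)
assert byte_prepack_shape((8, 1000, 128)) == (8, 1000, 136)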


@ -1,220 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#endif

namespace at::native {

using namespace mps;
using namespace at::sparse;

#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Sparse_metallib.h>
#endif


static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {

  TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
  TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
              "flatten_indices: indices.size(0) must equal size.size()");

  int64_t sparse_dim = indices.size(0);
  int64_t nnz = indices.size(1);

  if (nnz == 0) {
    return at::empty({0}, indices.options().dtype(kLong));
  }

  std::vector<int64_t> strides(sparse_dim);
  strides[sparse_dim - 1] = 1;
  for (int64_t i = sparse_dim - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * size[i + 1];
  }

  Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return flat_indices;
}
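An illustrative NumPy equivalent of flatten_indices (CPU reference, not the Metal kernel): each coordinate is multiplied by the row-major stride of its sparse dimension and the products are summed, turning a (sparse_dim, nnz) index matrix into nnz linear keys.

import numpy as np

def flatten_indices_ref(indices: np.ndarray, size) -> np.ndarray:
    # indices: (sparse_dim, nnz) coordinates; size: sizes of the sparse dims
    sparse_dim = indices.shape[0]
    strides = np.ones(sparse_dim, dtype=np.int64)
    for i in range(sparse_dim - 2, -1, -1):   # row-major strides, as in the kernel
        strides[i] = strides[i + 1] * size[i + 1]
    return (strides[:, None] * indices).sum(axis=0)

idx = np.array([[0, 1, 1], [2, 0, 3]])        # coordinates in a 2x4 sparse tensor
print(flatten_indices_ref(idx, (2, 4)))       # [2 4 7]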

static Tensor compute_output_positions(const Tensor& is_unique) {

  int64_t nnz = is_unique.size(0);
  if (nnz == 0) {
    return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
  }

  Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, is_unique, positions);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return positions;
}

static Tensor compute_output_positions_parallel(const Tensor& is_unique) {

  int64_t nnz = is_unique.size(0);
  if (nnz == 0) {
    return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
  }

  // for small arrays, use simple kernel
  // speed of the naive kernel drops off after 4096 nnz elements
  if (nnz <= 4096) {
    return compute_output_positions(is_unique);
  }
  auto stream = getCurrentMPSStream();
  Tensor positions = is_unique.to(kInt);
  // Kogge-Stone parallel prefix sum
  Tensor positions_cloned = positions.clone();

  for (int64_t stride = 1; stride < nnz; stride *= 2) {
    dispatch_sync_with_rethrow(stream->queue(), ^() {
      @autoreleasepool {
        auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step");
        auto encoder = stream->commandEncoder();
        [encoder setComputePipelineState:pipeline];

        mtl_setArgs(encoder, positions, positions_cloned, stride);
        mtl_dispatch1DJob(encoder, pipeline, nnz);
      }
    });
    std::swap(positions, positions_cloned);
  }

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, positions, positions_cloned);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return positions_cloned;
}
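A NumPy sketch of the scan above (illustrative; each loop iteration corresponds to one kogge_stone_step dispatch): stride-doubling produces an inclusive prefix sum of the is_unique mask, and the final shift_right turns it into exclusive output positions.

import numpy as np

def exclusive_prefix_sum_kogge_stone(is_unique: np.ndarray) -> np.ndarray:
    positions = is_unique.astype(np.int32)
    nnz = positions.size
    stride = 1
    while stride < nnz:                     # one kernel dispatch per stride
        shifted = np.zeros_like(positions)
        shifted[stride:] = positions[:-stride]
        positions = positions + shifted     # positions[i] += positions[i - stride]
        stride *= 2
    out = np.empty_like(positions)          # shift right: inclusive -> exclusive
    out[0] = 0
    out[1:] = positions[:-1]
    return out

u = np.array([1, 0, 1, 1, 0, 1], dtype=bool)
print(exclusive_prefix_sum_kogge_stone(u))  # [0 1 1 2 3 3]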

static std::pair<Tensor, int32_t> mark_unique_and_count(const Tensor& flat_indices) {

  int64_t nnz = flat_indices.size(0);
  if (nnz == 0) {
    return {at::empty({0}, flat_indices.options().dtype(kBool)), 0};
  }

  Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool));
  Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, flat_indices, is_unique, count_result);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  int32_t num_unique = count_result.item<int32_t>();

  return {is_unique, num_unique};
}
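On sorted keys the uniqueness mask is simply "differs from the left neighbour"; a NumPy reference of mark_unique_and_count (illustrative only):

import numpy as np

def mark_unique_and_count_ref(sorted_keys: np.ndarray):
    nnz = sorted_keys.size
    if nnz == 0:
        return np.zeros(0, dtype=bool), 0
    is_unique = np.empty(nnz, dtype=bool)
    is_unique[0] = True
    is_unique[1:] = sorted_keys[1:] != sorted_keys[:-1]   # a new group starts here
    return is_unique, int(is_unique.sum())

keys = np.array([2, 2, 4, 7, 7, 7, 9])
mask, count = mark_unique_and_count_ref(keys)
print(mask, count)   # [ True False  True  True False False  True] 4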

SparseTensor _coalesce_sparse_mps(const SparseTensor& self) {
  int64_t nnz = self._nnz();
  TORCH_INTERNAL_ASSERT(!self.is_coalesced());
  if (nnz < 2) {
    SparseTensor dst = self.clone();
    dst._coalesced_(true);
    return dst;
  }

  Tensor indices = self._indices();
  Tensor values = self._values();

  Tensor flat_indices = flatten_indices(indices, self.sizes());
  Tensor sorted_order = flat_indices.argsort();
  Tensor flat_indices_sorted = flat_indices.index({sorted_order});
  values = values.index({sorted_order});
  indices = indices.index_select(1, sorted_order);

  auto unique_info = mark_unique_and_count(flat_indices_sorted);
  Tensor is_unique = unique_info.first;
  int32_t newNnz = unique_info.second;

  Tensor output_positions = compute_output_positions_parallel(is_unique);

  Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options());
  auto outValuesSize = values.sizes().vec();
  outValuesSize[0] = newNnz;
  Tensor out_values = at::zeros(outValuesSize, values.options());

  Tensor is_unique_local = is_unique;
  int64_t sparse_dim = indices.size(0);

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values));
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      const uint32_t numThreads = static_cast<uint32_t>(nnz);
      const uint32_t valueSize = static_cast<uint32_t>(values.numel() / nnz);
      mtl_setArgs(encoder,
                  flat_indices_sorted,
                  indices,
                  values,
                  is_unique_local,
                  output_positions,
                  out_indices,
                  out_values,
                  numThreads,
                  valueSize,
                  sparse_dim,
                  newNnz);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true);
  return result;
}
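Putting the pieces together, an illustrative NumPy reference of the coalesce pipeline in _coalesce_sparse_mps: flatten and sort the indices, mark group starts, compute exclusive positions, then sum values of duplicate indices into their group slot (the role of coalesce_with_positions_kernel).

import numpy as np

def coalesce_ref(indices: np.ndarray, values: np.ndarray, size):
    # indices: (sparse_dim, nnz), values: (nnz, ...); duplicate indices are summed
    strides = np.cumprod([1] + list(size[:0:-1]))[::-1]    # row-major strides
    flat = (strides[:, None] * indices).sum(axis=0)
    order = np.argsort(flat, kind="stable")
    flat, indices, values = flat[order], indices[:, order], values[order]
    is_unique = np.r_[True, flat[1:] != flat[:-1]]
    positions = np.cumsum(is_unique) - 1                   # output slot per element
    out_indices = indices[:, is_unique]
    out_values = np.zeros((int(is_unique.sum()),) + values.shape[1:], values.dtype)
    np.add.at(out_values, positions, values)               # accumulate duplicates
    return out_indices, out_values

idx = np.array([[0, 1, 0], [3, 2, 3]])
val = np.array([1.0, 2.0, 5.0])
print(coalesce_ref(idx, val, (2, 4)))   # ([[0, 1], [3, 2]], [6., 2.])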

} // namespace at::native
Some files were not shown because too many files have changed in this diff