Compare commits

..

2 Commits

Author SHA1 Message Date
f00a1b0349 Fix profiler stack trace names 2025-08-06 21:20:14 -07:00
bc67bce2e5 Working setup with runnable PyTorch on Codex.
Signed-off-by: Edward Yang <ezyang@meta.com>
ghstack-source-id: 132668d46021090fe3ef197fb25ba762ce42667c
Pull-Request: https://github.com/pytorch/pytorch/pull/159968
2025-08-06 14:56:40 -07:00
605 changed files with 44855 additions and 13024 deletions

View File

@ -438,7 +438,9 @@ def build_torchvision(
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
build_vars += (
f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
)
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
@ -493,7 +495,9 @@ def build_torchdata(
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
build_vars += (
f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
)
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
@ -549,7 +553,9 @@ def build_torchtext(
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
build_vars += (
f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
)
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"
@ -607,7 +613,9 @@ def build_torchaudio(
)
build_vars += f"BUILD_VERSION={version}.dev{build_date}"
elif build_version is not None:
build_vars += f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-', maxsplit=1)[0]}"
build_vars += (
f"BUILD_VERSION={build_version} PYTORCH_VERSION={branch[1:].split('-')[0]}"
)
if host.using_docker():
build_vars += " CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000"

View File

@ -176,7 +176,7 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10
else
@ -190,9 +190,7 @@ case "$tag" in
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-noble-rocm-alpha-py3)
ANACONDA_PYTHON_VERSION=3.12
@ -204,6 +202,7 @@ case "$tag" in
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
INDUCTOR_BENCHMARKS=yes
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950"
;;
pytorch-linux-jammy-xpu-2025.0-py3)

View File

@ -1 +1 @@
v4.54.0
243e186efbf7fb93328dd6b34927a4e8c8f24395

View File

@ -66,9 +66,8 @@ function do_cpython_build {
ln -s pip3 ${prefix}/bin/pip
fi
# install setuptools since python 3.12 is required to use distutils
# packaging is needed to create symlink since wheel no longer provides needed information
${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0
local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))")
${prefix}/bin/pip install wheel==0.45.1 setuptools==80.9.0
local abi_tag=$(${prefix}/bin/python -c "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag; print('{0}{1}-{2}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag()))")
ln -sf ${prefix} /opt/python/${abi_tag}
}
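The abi_tag one-liner above is dense; here is the packaging-based variant unrolled into plain Python, using the same modules and format string that appear in the diff. It is a readability sketch, not a drop-in replacement for the shell line:

import sysconfig
from packaging.tags import interpreter_name, interpreter_version

# "t" marks a free-threaded (no-GIL) interpreter, per the sysconfig flag checked above.
suffix = "t" if sysconfig.get_config_var("Py_GIL_DISABLED") else ""
abi_tag = "{0}{1}-{0}{1}{2}".format(interpreter_name(), interpreter_version(), suffix)
print(abi_tag)  # e.g. "cp312-cp312", or "cp313-cp313t" on a free-threaded build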

View File

@ -26,15 +26,15 @@ function install_torchbench() {
python install.py --continue_on_fail
# soxr comes from https://github.com/huggingface/transformers/pull/39429
pip install transformers==4.54.0 soxr==0.5.0
# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
chown -R jenkins torchbench
chown -R jenkins /opt/conda
}
# Pango is needed for weasyprint which is needed for doctr
@ -48,4 +48,4 @@ install_huggingface
install_timm
# Clean up
conda_run pip uninstall -y torch torchvision torchaudio triton torchao
conda_run pip uninstall -y torch torchvision torchaudio triton

View File

@ -34,27 +34,18 @@ function install_ubuntu() {
# The xpu-smi packages
apt-get install -y flex bison xpu-smi
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Compute and Media Runtimes
apt-get install -y \
intel-opencl-icd intel-level-zero-gpu level-zero \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
else # rolling driver
apt-get install -y \
intel-opencl-icd libze-intel-gpu1 libze1 \
intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
# Compute and Media Runtimes
apt-get install -y \
intel-opencl-icd intel-level-zero-gpu level-zero \
intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
apt-get install -y intel-ocloc
fi
# Development Packages
apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
# Install Intel Support Packages
apt-get install -y ${XPU_PACKAGES}
@ -139,11 +130,11 @@ function install_sles() {
}
# Default use GPU driver rolling releases
XPU_DRIVER_VERSION=""
if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
# Use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2350"
# Default use GPU driver LTS releases
XPU_DRIVER_VERSION="/lts/2350"
if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then
# Use GPU driver rolling releases
XPU_DRIVER_VERSION=""
fi
# Default use Intel® oneAPI Deep Learning Essentials 2025.0
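The tail of this hunk only toggles which apt path the driver packages come from. A minimal Python sketch of the channel selection in the variant that defaults to the LTS channel, assuming the same XPU_DRIVER_TYPE variable and the /lts/2350 suffix shown above:

import os

def xpu_driver_suffix(driver_type: str = "") -> str:
    # Rolling releases sit at the repository root; LTS releases live under /lts/2350.
    driver_type = (driver_type or os.environ.get("XPU_DRIVER_TYPE", "")).lower()
    return "" if driver_type == "rolling" else "/lts/2350"

print(xpu_driver_suffix("rolling"))  # ""
print(xpu_driver_suffix("lts"))      # "/lts/2350"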

View File

@ -63,12 +63,11 @@ lark==0.12.0
#Pinned versions: 0.12.0
#test that import:
librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x"
librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x"
librosa>=0.6.2 ; python_version < "3.11"
librosa==0.10.2 ; python_version == "3.12"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2
#test that import: test_spectral_ops.py
#librosa depends on numba; disable it for s390x while numba is disabled too
#mkl #this breaks linux-bionic-rocm4.5-py3.7
#Description: Intel oneAPI Math Kernel Library
@ -111,15 +110,14 @@ ninja==1.11.1.3
#Pinned versions: 1.11.1.3
#test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py
numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x"
numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x"
numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
numba==0.49.0 ; python_version < "3.9"
numba==0.55.2 ; python_version == "3.9"
numba==0.55.2 ; python_version == "3.10"
numba==0.60.0 ; python_version == "3.12"
#Description: Just-In-Time Compiler for Numerical Functions
#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
#For numba issue see https://github.com/pytorch/pytorch/issues/51511
#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073
#numpy
#Description: Provides N-dimensional arrays and linear algebra
@ -309,7 +307,7 @@ pytest-cpp==2.3.0
#Pinned versions: 2.3.0
#test that import:
z3-solver==4.15.1.0 ; platform_machine != "s390x"
z3-solver==4.15.1.0
#Description: The Z3 Theorem Prover Project
#Pinned versions:
#test that import:
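The librosa, numba, and z3-solver pins above gate installation with PEP 508 environment markers. A small sketch of how such a marker evaluates on the current interpreter, using the packaging library and the s390x marker from this file:

from packaging.markers import Marker

marker = Marker('python_version == "3.12" and platform_machine != "s390x"')
# True only on CPython 3.12 running on a non-s390x machine; False elsewhere.
print(marker.evaluate())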

View File

@ -138,11 +138,28 @@ fi
echo "Calling setup.py bdist at $(date)"
time CMAKE_ARGS=${CMAKE_ARGS[@]} \
EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \
BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \
BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
CMAKE_FRESH=1 python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
else
time CMAKE_ARGS=${CMAKE_ARGS[@]} \
EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \
BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \
USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \
python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR
fi
echo "Finished setup.py bdist at $(date)"
# Build libtorch packages
@ -255,6 +272,10 @@ ls /tmp/$WHEELHOUSE_DIR
mkdir -p "/$WHEELHOUSE_DIR"
mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true
fi
if [[ -n "$BUILD_PYTHONLESS" ]]; then
mkdir -p /$LIBTORCH_HOUSE_DIR
mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR
@ -431,8 +452,16 @@ if [[ -z "$BUILD_PYTHONLESS" ]]; then
pushd $PYTORCH_ROOT/test
# Install the wheel for this Python version
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true
fi
pip uninstall -y "$TORCH_PACKAGE_NAME"
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v
fi
pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v
# Print info on the libraries installed in this wheel
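The split-build branch above drives setup.py twice: first with BUILD_LIBTORCH_WHL=1 to produce a libtorch-only wheel, then with BUILD_PYTHON_ONLY=1 (plus CMAKE_FRESH=1) to build the Python wheel on top of it. A hedged Python sketch of that two-phase loop, using only the environment toggles visible in the diff; the wheelhouse path is a placeholder:

import os
import subprocess
import sys

WHEELHOUSE = "/tmp/wheelhouse"  # placeholder for /tmp/$WHEELHOUSE_DIR
phases = (
    {"BUILD_LIBTORCH_WHL": "1", "BUILD_PYTHON_ONLY": "0"},                      # libtorch wheel
    {"BUILD_LIBTORCH_WHL": "0", "BUILD_PYTHON_ONLY": "1", "CMAKE_FRESH": "1"},  # python wheel
)
for phase in phases:
    subprocess.check_call(
        [sys.executable, "setup.py", "bdist_wheel", "-d", WHEELHOUSE],
        env={**os.environ, **phase},
    )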

View File

@ -50,6 +50,9 @@ if [[ ${BUILD_ENVIRONMENT} == *"parallelnative"* ]]; then
export ATEN_THREADING=NATIVE
fi
# Enable LLVM dependency for TensorExpr testing
export USE_LLVM=/opt/llvm
export LLVM_DIR=/opt/llvm/lib/cmake/llvm
if ! which conda; then
# In ROCm CIs, we are doing cross compilation on build machines with
@ -173,7 +176,7 @@ fi
# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
# memory to build and will OOM
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' ' '\n' | sed 's/$/>= 8.0/' | bc | grep -q 1; then
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ 1 -eq $(echo "${TORCH_CUDA_ARCH_LIST} >= 8.0" | bc) ]]; then
export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j 2"
fi
@ -189,6 +192,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then
export USE_ASAN=1
export REL_WITH_DEB_INFO=1
export UBSAN_FLAGS="-fno-sanitize-recover=all"
unset USE_LLVM
fi
if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then
@ -261,13 +265,22 @@ else
WERROR=1 python setup.py clean
WERROR=1 python setup.py bdist_wheel
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
python3 tools/packaging/split_wheel.py bdist_wheel
else
WERROR=1 python setup.py bdist_wheel
fi
else
python setup.py clean
if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then
source .ci/pytorch/install_cache_xla.sh
fi
python setup.py bdist_wheel
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
echo "USE_SPLIT_BUILD cannot be used with xla or rocm"
exit 1
else
python setup.py bdist_wheel
fi
fi
pip_install_whl "$(echo dist/*.whl)"
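The FlashAttention guard above differs in how it inspects TORCH_CUDA_ARCH_LIST: one side evaluates each listed compute capability against 8.0, the other feeds the raw string to bc. A Python sketch of the per-entry check, assuming plain numeric entries such as "7.5 8.9" (real lists can carry suffixes like "+PTX" that would need stripping):

def builds_flash_attention(arch_list: str) -> bool:
    # Build the flash_attention target if any listed arch is 8.0 or newer.
    return any(float(arch) >= 8.0 for arch in arch_list.split())

print(builds_flash_attention("7.5 8.9"))  # True
print(builds_flash_attention("7.5"))      # False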

View File

@ -175,9 +175,6 @@ checkout_install_torchbench() {
python install.py --continue_on_fail
fi
# soxr comes from https://github.com/huggingface/transformers/pull/39429
pip install transformers==4.54.0 soxr==0.5.0
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd

View File

@ -1051,10 +1051,20 @@ test_libtorch_api() {
mkdir -p $TEST_REPORTS_DIR
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" "$TORCH_BIN_DIR"/test_api --gtest_filter='-IMethodTest.*' --gtest_output=xml:$TEST_REPORTS_DIR/test_api.xml
"$TORCH_BIN_DIR"/test_tensorexpr --gtest_output=xml:$TEST_REPORTS_DIR/test_tensorexpr.xml
else
# Exclude IMethodTest, which relies on torch::deploy and will instead be run in test_deploy
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"
# On s390x, pytorch is built without llvm.
# Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
# test fails with errors like:
# JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
# unknown file: Failure
# C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
fi
# quantization is not fully supported on s390x yet
@ -1682,6 +1692,7 @@ elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then
elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
install_torchaudio
install_torchvision
install_torchao
id=$((SHARD_NUMBER-1))
# https://github.com/opencv/opencv-python/issues/885
pip_install opencv-python==4.8.0.74

View File

@ -61,10 +61,9 @@ if "%USE_XPU%"=="1" (
call "C:\Program Files (x86)\Intel\oneAPI\compiler\latest\env\vars.bat"
call "C:\Program Files (x86)\Intel\oneAPI\ocloc\latest\env\vars.bat"
if errorlevel 1 exit /b 1
:: Reduce build time
SET TORCH_XPU_ARCH_LIST=bmg
:: Re-setup python env for build
call pip install -r requirements.txt
:: Reduce build time. Only have MTL self-hosted runner now
SET TORCH_XPU_ARCH_LIST=xe-lpg
SET USE_KINETO=0
)
@echo on

View File

@ -192,6 +192,9 @@ retry brew install libomp
# For USE_DISTRIBUTED=1 on macOS, need libuv, which is build as part of tensorpipe submodule
export USE_DISTRIBUTED=1
if [[ -n "$CROSS_COMPILE_ARM64" ]]; then
export CMAKE_OSX_ARCHITECTURES=arm64
fi
export USE_MKLDNN=OFF
export USE_QNNPACK=OFF
export BUILD_TEST=OFF
@ -199,7 +202,16 @@ export BUILD_TEST=OFF
pushd "$pytorch_rootdir"
echo "Calling setup.py bdist_wheel at $(date)"
python setup.py bdist_wheel -d "$whl_tmp_dir"
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 python setup.py bdist_wheel -d "$whl_tmp_dir"
echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)"
echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 CMAKE_FRESH=1 python setup.py bdist_wheel -d "$whl_tmp_dir"
echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)"
else
python setup.py bdist_wheel -d "$whl_tmp_dir"
fi
echo "Finished setup.py bdist_wheel at $(date)"

View File

@ -65,8 +65,16 @@ fi
if [[ "$PACKAGE_TYPE" != libtorch ]]; then
if [[ "\$BUILD_ENVIRONMENT" != *s390x* ]]; then
pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
retry pip install -q numpy protobuf typing-extensions
if [[ "$USE_SPLIT_BUILD" == "true" ]]; then
pkg_no_python="$(ls -1 /final_pkgs/torch_no_python* | sort |tail -1)"
pkg_torch="$(ls -1 /final_pkgs/torch-* | sort |tail -1)"
# todo: after folder is populated use the pypi_pkg channel instead
pip install "\$pkg_no_python" "\$pkg_torch" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}_pypi_pkg"
retry pip install -q numpy protobuf typing-extensions
else
pip install "\$pkg" --index-url "https://download.pytorch.org/whl/\${CHANNEL}/${DESIRED_CUDA}"
retry pip install -q numpy protobuf typing-extensions
fi
else
pip install "\$pkg"
retry pip install -q numpy protobuf typing-extensions

View File

@ -134,6 +134,7 @@ export DESIRED_PYTHON="${DESIRED_PYTHON:-}"
export DESIRED_CUDA="$DESIRED_CUDA"
export LIBTORCH_VARIANT="${LIBTORCH_VARIANT:-}"
export BUILD_PYTHONLESS="${BUILD_PYTHONLESS:-}"
export USE_SPLIT_BUILD="${USE_SPLIT_BUILD:-}"
if [[ "${OSTYPE}" == "msys" ]]; then
export LIBTORCH_CONFIG="${LIBTORCH_CONFIG:-}"
if [[ "${LIBTORCH_CONFIG:-}" == 'debug' ]]; then

View File

@ -23,6 +23,10 @@ if [[ "${DRY_RUN}" = "disabled" ]]; then
AWS_S3_CP="aws s3 cp"
fi
if [[ "${USE_SPLIT_BUILD:-false}" == "true" ]]; then
UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_pypi_pkg"
fi
# this is special build with all dependencies packaged
if [[ ${BUILD_NAME} == *-full* ]]; then
UPLOAD_SUBFOLDER="${UPLOAD_SUBFOLDER}_full"

View File

@ -24,6 +24,7 @@ runs:
-e PYTORCH_FINAL_PACKAGE_DIR \
-e PYTORCH_ROOT \
-e SKIP_ALL_TESTS \
-e USE_SPLIT_BUILD \
--tty \
--detach \
-v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \

View File

@ -1 +1 @@
bdb88e1d66f272cad72156c90ac8428ca61a601c
6fbc710b617f79b992ef2ebc7f95e818aa390293

View File

@ -1 +1 @@
458e74eb907f96069e6d8a4f3c9f457001fef2ea
6a39ba85fe0f2fff9494b5eccea717c93510c230

View File

@ -1 +1 @@
095faec1e7b6cc47220181e74ae9cde2605f9b00
b6a5b82b9948b610fa4c304d0d869c82b8f17db1

View File

@ -273,6 +273,7 @@ def generate_wheels_matrix(
os: str,
arches: Optional[list[str]] = None,
python_versions: Optional[list[str]] = None,
use_split_build: bool = False,
) -> list[dict[str, str]]:
package_type = "wheel"
if os == "linux" or os == "linux-aarch64" or os == "linux-s390x":
@ -320,6 +321,15 @@ def generate_wheels_matrix(
):
continue
if use_split_build and (
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"
):
raise RuntimeError(
"Split build is only supported on linux with cuda 12* and cpu.\n"
f"Currently attempting to build on arch version {arch_version} and os {os}.\n"
"Please modify the matrix generation to exclude this combination."
)
# cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
if (
@ -334,6 +344,7 @@ def generate_wheels_matrix(
"gpu_arch_type": gpu_arch_type,
"gpu_arch_version": gpu_arch_version,
"desired_cuda": desired_cuda,
"use_split_build": "True" if use_split_build else "False",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version].split(
":"
)[0],
@ -366,6 +377,7 @@ def generate_wheels_matrix(
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "True" if use_split_build else "False",
"container_image": WHEEL_CONTAINER_IMAGES[
arch_version
].split(":")[0],
@ -388,6 +400,7 @@ def generate_wheels_matrix(
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"use_split_build": "True" if use_split_build else "False",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version].split(
":"
)[0],
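The new use_split_build flag above is stamped into every wheel config and guarded by a RuntimeError for unsupported combinations. A condensed sketch of that guard, restating only the check visible in this hunk:

def check_split_build(arch_version: str, os_name: str, use_split_build: bool) -> None:
    # Split build is only supported on linux with the listed CUDA versions or cpu.
    if use_split_build and (
        arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os_name != "linux"
    ):
        raise RuntimeError(
            "Split build is only supported on linux with cuda 12* and cpu.\n"
            f"Currently attempting to build on arch version {arch_version} and os {os_name}."
        )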

View File

@ -59,7 +59,9 @@ class BinaryBuildWorkflow:
is_scheduled: str = ""
branches: str = "nightly"
# Mainly for macos
cross_compile_arm64: bool = False
macos_runner: str = "macos-14-xlarge"
use_split_build: bool = False
# Mainly used for libtorch builds
build_variant: str = ""
@ -70,6 +72,9 @@ class BinaryBuildWorkflow:
for item in [self.os, "binary", self.package_type, self.build_variant]
if item != ""
)
if self.use_split_build:
# added to distinguish concurrency groups
self.build_environment += "-split"
def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
output_file_path = (
@ -112,6 +117,21 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
isolated_workflow=True,
),
),
# See https://github.com/pytorch/pytorch/issues/138750
# BinaryBuildWorkflow(
# os=OperatingSystem.LINUX,
# package_type="manywheel",
# build_configs=generate_binary_build_matrix.generate_wheels_matrix(
# OperatingSystem.LINUX,
# use_split_build=True,
# arches=["11.8", "12.1", "12.4", "cpu"],
# ),
# ciflow_config=CIFlowConfig(
# labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
# isolated_workflow=True,
# ),
# use_split_build=True,
# ),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="libtorch",
@ -155,11 +175,27 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["12.8"],
python_versions=["3.12"],
arches=["12.6", "12.8", "12.9"],
python_versions=["3.9"],
),
branches="main",
),
# See https://github.com/pytorch/pytorch/issues/138750
# BinaryBuildWorkflow(
# os=OperatingSystem.LINUX,
# package_type="manywheel",
# build_configs=generate_binary_build_matrix.generate_wheels_matrix(
# OperatingSystem.LINUX,
# arches=["11.8", "12.1", "12.4"],
# python_versions=["3.9"],
# use_split_build=True,
# ),
# ciflow_config=CIFlowConfig(
# labels={LABEL_CIFLOW_PERIODIC},
# ),
# branches="main",
# use_split_build=True,
# ),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="libtorch",
@ -302,6 +338,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
generate_binary_build_matrix.RELEASE,
libtorch_variants=["shared-with-deps"],
),
cross_compile_arm64=False,
macos_runner="macos-14-xlarge",
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_LIBTORCH},
@ -314,6 +351,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.MACOS_ARM64
),
cross_compile_arm64=False,
macos_runner="macos-14-xlarge",
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
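The -split suffix added in __post_init__ above keeps split and non-split binary jobs in separate concurrency groups. A stripped-down sketch of that naming, keeping only the fields that appear in this hunk (the real class has more):

from dataclasses import dataclass, field

@dataclass
class WorkflowNameSketch:  # hypothetical, trimmed stand-in for BinaryBuildWorkflow
    os: str
    package_type: str
    build_variant: str = ""
    use_split_build: bool = False
    build_environment: str = field(init=False)

    def __post_init__(self) -> None:
        parts = [self.os, "binary", self.package_type, self.build_variant]
        self.build_environment = "-".join(p for p in parts if p != "")
        if self.use_split_build:
            self.build_environment += "-split"  # distinguishes concurrency groups

print(WorkflowNameSketch("linux", "manywheel", use_split_build=True).build_environment)
# linux-binary-manywheel-split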

View File

@ -262,12 +262,7 @@ def is_exception_branch(branch: str) -> bool:
"""
Branches that get opted out of experiments by default, until they're explicitly enabled.
"""
return branch.split("/", maxsplit=1)[0] in {
"main",
"nightly",
"release",
"landchecks",
}
return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}
def load_yaml(yaml_text: str) -> Any:
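For the branch check above, dropping maxsplit=1 cannot change the first path component, so the opt-out set matches the same branches. A quick check with hypothetical branch names:

def is_exception_branch(branch: str) -> bool:
    # Branches opted out of experiments by default, as in the hunk above.
    return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"}

assert "release/2.8".split("/", maxsplit=1)[0] == "release/2.8".split("/")[0] == "release"
assert is_exception_branch("release/2.8")
assert not is_exception_branch("gh/someuser/123/head")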

View File

@ -47,6 +47,9 @@ env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SKIP_ALL_TESTS: 0
{%- if cross_compile_arm64 %}
CROSS_COMPILE_ARM64: 1
{% endif %}
!{{ common.concurrency(build_environment) }}
jobs:

View File

@ -25,6 +25,11 @@
DOCKER_IMAGE: !{{ config["container_image"] }}
DOCKER_IMAGE_TAG_PREFIX: !{{ config["container_image_tag_prefix"] }}
{%- endif %}
{%- if config["package_type"] == "manywheel" %}
{%- if config.use_split_build is defined %}
use_split_build: !{{ config["use_split_build"] }}
{%- endif %}
{%- endif %}
{%- if config["package_type"] == "libtorch" %}
{%- if config["libtorch_config"] %}
LIBTORCH_CONFIG: !{{ config["libtorch_config"] }}

View File

@ -26,6 +26,13 @@ on:
default: 240
type: number
description: timeout for the job
use_split_build:
description: |
[Experimental] Build a libtorch-only wheel, then build the PyTorch wheel
so that its binaries are built from the libtorch wheel.
required: false
type: boolean
default: false
ALPINE_IMAGE:
required: false
type: string
@ -110,6 +117,7 @@ jobs:
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
steps:
- name: Make the env permanent during this workflow (but not the secrets)
shell: bash
@ -134,6 +142,7 @@ jobs:
echo "PR_NUMBER=${{ env.PR_NUMBER }}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
echo "SHA1=${{ env.SHA1 }}"
echo "USE_SPLIT_BUILD=${{ env.use_split_build }}"
} >> "${GITHUB_ENV} }}"
- name: List the env
@ -252,6 +261,7 @@ jobs:
-e PYTORCH_ROOT \
-e SKIP_ALL_TESTS \
-e PYTORCH_EXTRA_INSTALL_REQUIREMENTS \
-e USE_SPLIT_BUILD \
--tty \
--detach \
-v "${GITHUB_WORKSPACE}/pytorch:/pytorch" \

View File

@ -64,6 +64,13 @@ on:
required: true
type: string
description: Hardware to run this job on. Valid values are linux.4xlarge, linux.4xlarge.nvidia.gpu, linux.arm64.2xlarge, and linux.rocm.gpu
use_split_build:
description: |
[Experimental] Build a libtorch-only wheel, then build the PyTorch wheel
so that its binaries are built from the libtorch wheel.
required: false
type: boolean
default: false
secrets:
github-token:
required: true
@ -97,6 +104,7 @@ jobs:
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
steps:
- name: Make the env permanent during this workflow (but not the secrets)
shell: bash
@ -121,6 +129,7 @@ jobs:
echo "PR_NUMBER=${{ env.PR_NUMBER }}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
echo "SHA1=${{ env.SHA1 }}"
echo "USE_SPLIT_BUILD=${{ env.USE_SPLIT_BUILD }}"
} >> "${GITHUB_ENV} }}"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"

View File

@ -51,6 +51,13 @@ on:
required: false
type: string
description: Desired python version
use_split_build:
description: |
[Experimental] Build a libtorch-only wheel, then build the PyTorch wheel
so that its binaries are built from the libtorch wheel.
required: false
type: boolean
default: false
secrets:
github-token:
required: true
@ -79,6 +86,7 @@ jobs:
PR_NUMBER: ${{ github.event.pull_request.number }}
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
USE_SPLIT_BUILD: ${{ inputs.use_split_build }}
steps:
- name: Checkout PyTorch
uses: pytorch/pytorch/.github/actions/checkout-pytorch@main

View File

@ -306,6 +306,7 @@ jobs:
-e OUR_GITHUB_JOB_ID \
-e HUGGING_FACE_HUB_TOKEN \
-e SCRIBE_GRAPHQL_ACCESS_TOKEN \
-e USE_SPLIT_BUILD \
-e BUILD_ADDITIONAL_PACKAGES \
--memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \
--memory-swap="${TOTAL_MEMORY_WITH_SWAP}g" \

View File

@ -61,7 +61,6 @@ jobs:
pytorch-linux-jammy-rocm-n-py3,
pytorch-linux-noble-rocm-n-py3,
pytorch-linux-noble-rocm-alpha-py3,
pytorch-linux-jammy-rocm-n-py3-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-clang12,
pytorch-linux-jammy-py3.9-gcc11,
pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks,

View File

@ -60,6 +60,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -83,6 +84,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -106,6 +108,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-aarch64
secrets:
@ -126,6 +129,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -152,6 +156,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda-aarch64-12_9
secrets:
@ -171,6 +176,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.10"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -194,6 +200,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -217,6 +224,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-aarch64
secrets:
@ -237,6 +245,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.10"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -263,6 +272,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cuda-aarch64-12_9
secrets:
@ -282,6 +292,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.11"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -305,6 +316,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -328,6 +340,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-aarch64
secrets:
@ -348,6 +361,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.11"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -374,6 +388,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cuda-aarch64-12_9
secrets:
@ -393,6 +408,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.12"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -416,6 +432,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -439,6 +456,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-aarch64
secrets:
@ -459,6 +477,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.12"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -485,6 +504,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cuda-aarch64-12_9
secrets:
@ -504,6 +524,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -527,6 +548,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -550,6 +572,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-aarch64
secrets:
@ -570,6 +593,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.13"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -596,6 +620,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cuda-aarch64-12_9
secrets:
@ -615,6 +640,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13t"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -638,6 +664,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13t"
build_name: manywheel-py3_13t-cpu-aarch64
build_environment: linux-aarch64-binary-manywheel
@ -661,6 +688,7 @@ jobs:
GPU_ARCH_TYPE: cpu-aarch64
DOCKER_IMAGE: manylinux2_28_aarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64
use_split_build: False
DESIRED_PYTHON: "3.13t"
build_name: manywheel-py3_13t-cpu-aarch64
secrets:
@ -681,6 +709,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.13t"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.arm64.m7g.4xlarge.ephemeral
@ -707,6 +736,7 @@ jobs:
GPU_ARCH_TYPE: cuda-aarch64
DOCKER_IMAGE: manylinuxaarch64-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.13t"
build_name: manywheel-py3_13t-cuda-aarch64-12_9
secrets:

View File

@ -42,7 +42,54 @@ jobs:
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
manywheel-py3_12-cuda12_8-build:
manywheel-py3_9-cuda12_6-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu126
GPU_ARCH_VERSION: 12.6
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.6
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_6
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.6.80; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.6.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.0.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.7.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.1.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.4.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.6.77; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.6.85; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.11.1.6; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_6-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_6-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu126
GPU_ARCH_VERSION: 12.6
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.6
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_6
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.4xlarge.nvidia.gpu # for other cuda versions, we use 4xlarge runner
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_8-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
@ -56,17 +103,18 @@ jobs:
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.8
DESIRED_PYTHON: "3.12"
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_12-cuda12_8
build_name: manywheel-py3_9-cuda12_8
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.8.4.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.3.3.83; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.9.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.3.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.8.90; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.8.93; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.13.1.3; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_12-cuda12_8-test: # Testing
manywheel-py3_9-cuda12_8-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_12-cuda12_8-build
- manywheel-py3_9-cuda12_8-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
@ -79,8 +127,56 @@ jobs:
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.8
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cuda12_8
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_8
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_9-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
needs: get-label-type
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: 12.9
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-cuda12_9
build_environment: linux-binary-manywheel
PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.9.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.4.1.4; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.10.19; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.7.5.82; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.5.10.65; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.27.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvshmem-cu12==3.3.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.9.79; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.9.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufile-cu12==1.14.1.1; platform_system == 'Linux' and platform_machine == 'x86_64'
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
manywheel-py3_9-cuda12_9-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs:
- manywheel-py3_9-cuda12_9-build
- get-label-type
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
PACKAGE_TYPE: manywheel
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu129
GPU_ARCH_VERSION: 12.9
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: cuda12.9
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cuda12_9
build_environment: linux-binary-manywheel
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
runs_on: linux.g4dn.4xlarge.nvidia.gpu # 12.8 and 12.9 build need sm_70+ runner
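Each generated job above pairs a GPU_ARCH_VERSION such as 12.6 with a DESIRED_CUDA string such as cu126 (the legacy variable the TODO comments refer to). A hypothetical sketch of that mapping for the CUDA case only; the real translate_desired_cuda helper in generate_binary_build_matrix.py covers more arch types:

def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
    # Illustration only: "12.6" -> "cu126"; non-CUDA arch types are handled elsewhere.
    if gpu_arch_type == "cuda":
        return "cu" + gpu_arch_version.replace(".", "")
    return gpu_arch_version

assert translate_desired_cuda("cuda", "12.6") == "cu126"
assert translate_desired_cuda("cuda", "12.9") == "cu129"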

File diff suppressed because it is too large

View File

@ -58,6 +58,7 @@ jobs:
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
use_split_build: False
DESIRED_PYTHON: "3.9"
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build_name: manywheel-py3_9-rocm6_4
@ -82,6 +83,7 @@ jobs:
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: manylinux2_28-builder
DOCKER_IMAGE_TAG_PREFIX: rocm6.4
use_split_build: False
DESIRED_PYTHON: "3.9"
steps:
- name: Setup ROCm

View File

@ -60,6 +60,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.9"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -83,6 +84,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -105,6 +107,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.9"
build_name: manywheel-py3_9-cpu-s390x
secrets:
@ -124,6 +127,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.10"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -147,6 +151,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -169,6 +174,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.10"
build_name: manywheel-py3_10-cpu-s390x
secrets:
@ -188,6 +194,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.11"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -211,6 +218,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -233,6 +241,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.11"
build_name: manywheel-py3_11-cpu-s390x
secrets:
@ -252,6 +261,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.12"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -275,6 +285,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -297,6 +308,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.12"
build_name: manywheel-py3_12-cpu-s390x
secrets:
@ -316,6 +328,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.13"
runs_on: linux.s390x
ALPINE_IMAGE: "docker.io/s390x/alpine"
@ -339,6 +352,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-s390x
build_environment: linux-s390x-binary-manywheel
@ -361,6 +375,7 @@ jobs:
GPU_ARCH_TYPE: cpu-s390x
DOCKER_IMAGE: pytorch/manylinuxs390x-builder
DOCKER_IMAGE_TAG_PREFIX: cpu-s390x
use_split_build: False
DESIRED_PYTHON: "3.13"
build_name: manywheel-py3_13-cpu-s390x
secrets:

View File

@ -85,7 +85,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-rocm-py3_10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },

View File

@ -77,7 +77,7 @@ jobs:
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-rocm-py3_10
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [

View File

@ -254,6 +254,68 @@ jobs:
timeout-minutes: 600
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-build-distributed:
name: linux-jammy-cuda12.8-py3.10-gcc11-build-distributed
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '7.5'
test-matrix: |
{ include: [
{ config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test-distributed:
name: linux-jammy-cuda12.8-py3.10-gcc11-test
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build-distributed
- target-determination
with:
timeout-minutes: 360
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build-distributed.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-build:
name: linux-jammy-cuda12.8-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test:
name: linux-jammy-cuda12.8-py3.10-gcc11
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- target-determination
with:
timeout-minutes: 360
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build:
name: linux-jammy-cuda12.8-cudnn9-py3.9-clang12
uses: ./.github/workflows/_linux-build.yml
@ -268,6 +330,30 @@ jobs:
]}
secrets: inherit
linux-jammy-py3_9-clang9-xla-build:
name: linux-jammy-py3_9-clang9-xla
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-clang9-xla
docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite
test-matrix: |
{ include: [
{ config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" },
]}
secrets: inherit
linux-jammy-py3_9-clang9-xla-test:
name: linux-jammy-py3_9-clang9-xla
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-py3_9-clang9-xla-build
with:
build-environment: linux-jammy-py3.9-clang9-xla
docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cpu-py3_10-gcc11-bazel-test:
name: linux-jammy-cpu-py3.10-gcc11-bazel-test
uses: ./.github/workflows/_bazel-build-test.yml
@ -343,6 +429,31 @@ jobs:
test-matrix: ${{ needs.linux-jammy-py3-clang12-executorch-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc9-inductor-build:
name: cuda12.8-py3.10-gcc9-sm75
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '7.5'
test-matrix: |
{ include: [
{ config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc9-inductor-test:
name: cuda12.8-py3.10-gcc9-sm75
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-cuda12_8-py3_10-gcc9-inductor-build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm75
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-xpu-2025_1-py3_9-build:
name: linux-jammy-xpu-2025.1-py3.9
uses: ./.github/workflows/_linux-build.yml

View File

@ -63,43 +63,6 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-build:
name: linux-jammy-cuda12.8-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: '7.5 8.9'
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" },
{ config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-test:
name: linux-jammy-cuda12.8-py3.10-gcc11
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-build
- target-determination
with:
timeout-minutes: 360
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.test-matrix }}
secrets: inherit
# no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated
linux-jammy-cuda12_8-py3_10-gcc11-no-ops-build:
name: linux-jammy-cuda12.8-py3.10-gcc11-no-ops

View File

@ -12,9 +12,7 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
permissions: read-all
jobs:
# There must be at least one job here to satisfy GitHub action workflow syntax
@ -53,27 +51,3 @@ jobs:
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-py3_9-clang9-xla-build:
name: linux-jammy-py3_9-clang9-xla
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-clang9-xla
docker-image-name: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/xla_base:v1.3-lite
test-matrix: |
{ include: [
{ config: "xla", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" },
]}
secrets: inherit
linux-jammy-py3_9-clang9-xla-test:
name: linux-jammy-py3_9-clang9-xla
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-py3_9-clang9-xla-build
with:
build-environment: linux-jammy-py3.9-clang9-xla
docker-image: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3_9-clang9-xla-build.outputs.test-matrix }}
secrets: inherit

3
.gitignore vendored
View File

@ -146,9 +146,6 @@ merge_record.json
torchgen/packaged/*
!torchgen/packaged/README.md
# This file is injected by ROCm build scripts to bootstrap in torch/__init__.py.
torch/_rocm_init.py
# IPython notebook checkpoints
.ipynb_checkpoints

View File

@ -1452,6 +1452,8 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'--no-black-binary',
'black==23.12.1',
'usort==1.0.8.post1',
'isort==6.0.1',
'ruff==0.12.2', # sync with RUFF

12
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,12 @@
repos:
- repo: local
hooks:
- id: lintrunner
name: Run Lintrunner in an isolated venv before every push. The first run may be slow...
entry: python scripts/run_lintrunner.py # wrapper below
        language: python          # pre-commit manages the venv for the wrapper
additional_dependencies: [] # wrapper handles lintrunner install
always_run: true
        stages: [pre-push]        # fire only on pre-push
        pass_filenames: false     # Lintrunner gets no per-file args
        verbose: true             # stream output as it is produced

View File

@ -8,7 +8,8 @@
Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir'
- Do NOT run setup.py, you do not have a working build environment
- Do NOT run pre-commit, it is not set up
- To run lint, run 'lintrunner -a' (which will auto-apply changes)
- To run lint, run 'lintrunner -a' (which will auto-apply changes). lintrunner
  ONLY accepts this flag; do not try to run it on individual files.
- Do NOT attempt to install dependencies, you do not have Internet access
- When you are ready to make a PR, do exactly these steps:
- git stash -u

View File

@ -239,9 +239,7 @@ option(USE_XPU "Use XPU" ON)
cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF)
cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)
option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
@ -253,6 +251,7 @@ cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
option(USE_GFLAGS "Use GFLAGS" OFF)
option(USE_GLOG "Use GLOG" OFF)
option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
@ -261,13 +260,11 @@ option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF)
option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(USE_MPS "Use MPS for macOS build" ON "MPS_FOUND" OFF)
option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(USE_NCCL "Use NCCL" ON
"USE_DISTRIBUTED;USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
"USE_CUDA OR USE_ROCM;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_XCCL "Use XCCL" ON
"USE_XPU;UNIX;NOT APPLE" OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON USE_NCCL OFF)
cmake_dependent_option(USE_RCCL "Use RCCL" ON "USE_NCCL;NOT WIN32" OFF)
cmake_dependent_option(USE_STATIC_NCCL "Use static NCCL" OFF "USE_NCCL" OFF)
cmake_dependent_option(USE_SYSTEM_NCCL "Use system-wide NCCL" OFF "USE_NCCL"
OFF)
@ -325,6 +322,7 @@ set(MKLDNN_ENABLE_CONCURRENT_EXEC ${USE_MKLDNN})
cmake_dependent_option(USE_MKLDNN_CBLAS "Use CBLAS in MKLDNN" OFF "USE_MKLDNN"
OFF)
option(USE_STATIC_MKL "Prefer to link with MKL statically (Unix only)" OFF)
option(USE_DISTRIBUTED "Use distributed" ON)
cmake_dependent_option(
USE_MPI "Use MPI for Caffe2. Only available if USE_DISTRIBUTED is on." ON
"USE_DISTRIBUTED" OFF)
@ -836,11 +834,10 @@ include(ExternalProject)
# ---[ Dependencies
# ---[ FBGEMM doesn't work on x86 32bit and
# CMAKE_SYSTEM_PROCESSOR thinks it's 64bit
if(USE_FBGEMM AND NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
message(WARNING
"x64 operating system is required for FBGEMM. "
"Not compiling with FBGEMM. "
"Turn this warning off by USE_FBGEMM=OFF.")
if(USE_FBGEMM
AND((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VOID_P EQUAL
4)
OR CMAKE_SYSTEM_PROCESSOR STREQUAL "x86"))
set(USE_FBGEMM OFF)
endif()

View File

@ -164,7 +164,6 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd
# torch.export
/torch/export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi
/torch/_export/ @avikchaudhuri @tugsbayasgalan @zhxchen17 @ydwu4 @angelayi
/torch/_export/serde/schema.py @SherlockNoMad @zhxchen17
# Dynamic Shapes
/torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka

View File

@ -1,4 +1,4 @@
![PyTorch Logo](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/pytorch-logo-dark.png)
![PyTorch Logo](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/pytorch-logo-dark.png)
--------------------------------------------------------------------------------
@ -72,7 +72,7 @@ Elaborating Further:
If you use NumPy, then you have used Tensors (a.k.a. ndarray).
![Tensor illustration](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/tensor_illustration.png)
![Tensor illustration](./docs/source/_static/img/tensor_illustration.png)
PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the
computation by a huge amount.
@ -99,7 +99,7 @@ from several research papers on this topic, as well as current and past work suc
While this technique is not unique to PyTorch, it's one of the fastest implementations of it to date.
You get the best of speed and flexibility for your crazy research.
![Dynamic graph](https://github.com/pytorch/pytorch/blob/9708fcf92db88b80b9010c68662d634434da3106/docs/source/_static/img/dynamic_graph.gif)
![Dynamic graph](https://github.com/pytorch/pytorch/raw/main/docs/source/_static/img/dynamic_graph.gif)
### Python First
@ -243,7 +243,7 @@ git submodule update --init --recursive
```bash
conda install cmake ninja
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section above
# Run this command from the PyTorch directory after cloning the source code using the “Get the PyTorch Source“ section below
pip install -r requirements.txt
```
@ -560,7 +560,7 @@ To learn more about making a contribution to Pytorch, please see our [Contributi
PyTorch is a community-driven project with several skillful engineers and researchers contributing to it.
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), [Alban Desmaison](https://github.com/albanD), [Piotr Bialecki](https://github.com/ptrblck) and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
PyTorch is currently maintained by [Soumith Chintala](http://soumith.ch), [Gregory Chanan](https://github.com/gchanan), [Dmytro Dzhulgakov](https://github.com/dzhulgakov), [Edward Yang](https://github.com/ezyang), and [Nikita Shulga](https://github.com/malfet) with major contributions coming from hundreds of talented individuals in various forms and means.
A non-exhaustive but growing list needs to mention: [Trevor Killeen](https://github.com/killeent), [Sasank Chilamkurthy](https://github.com/chsasank), [Sergey Zagoruyko](https://github.com/szagoruyko), [Adam Lerer](https://github.com/adamlerer), [Francisco Massa](https://github.com/fmassa), [Alykhan Tejani](https://github.com/alykhantejani), [Luca Antiga](https://github.com/lantiga), [Alban Desmaison](https://github.com/albanD), [Andreas Koepf](https://github.com/andreaskoepf), [James Bradbury](https://github.com/jekbradbury), [Zeming Lin](https://github.com/ebetica), [Yuandong Tian](https://github.com/yuandong-tian), [Guillaume Lample](https://github.com/glample), [Marat Dukhan](https://github.com/Maratyszcza), [Natalia Gimelshein](https://github.com/ngimel), [Christian Sarofeen](https://github.com/csarofeen), [Martin Raison](https://github.com/martinraison), [Edward Yang](https://github.com/ezyang), [Zachary Devito](https://github.com/zdevito). <!-- codespell:ignore -->
Note: This project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch.

View File

@ -119,8 +119,6 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp")
file(GLOB_RECURSE native_mps_mm "native/mps/*.mm")
file(GLOB_RECURSE native_mps_metal "native/mps/*.metal")
file(GLOB_RECURSE native_mps_h "native/mps/*.h")
file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm")
file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal")
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
file(GLOB native_quantized_cpp
@ -180,27 +178,26 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
caffe2_update_option(USE_ROCM_CK_SDPA ON)
endif()
if(USE_ROCM_CK_SDPA)
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
      Consider setting the PYTORCH_ROCM_ARCH env var to the gfx arch you need to build for")
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
          Consider setting the PYTORCH_ROCM_ARCH env var to the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# FAv3 Generation
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
endif()
endif()
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# FAv3 Generation
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
@ -419,42 +416,40 @@ if(USE_CUDA)
endif()
if(USE_ROCM)
if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
# NOTE: The PyTorch build does not actually add_subdirectory
# third_party/composable_kernel or use it as a CMake library. What is used
# is header only, so this should be ok, except that the CMake build generates
# a ck/config.h. We just do that part here. Without this, the ck.h from the
# ROCM SDK may get accidentally used instead.
function(_pytorch_rocm_generate_ck_conf)
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
set(CK_ENABLE_BF8 "ON")
set(CK_USE_XDL "ON")
set(CK_USE_WMMA "ON")
configure_file(
"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
)
endfunction()
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
endif()
# NOTE: The PyTorch build does not actually add_subdirectory
# third_party/composable_kernel or use it as a CMake library. What is used
# is header only, so this should be ok, except that the CMake build generates
# a ck/config.h. We just do that part here. Without this, the ck.h from the
# ROCM SDK may get accidentally used instead.
function(_pytorch_rocm_generate_ck_conf)
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
set(CK_ENABLE_BF8 "ON")
set(CK_USE_XDL "ON")
set(CK_USE_WMMA "ON")
configure_file(
"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
)
endfunction()
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
# Next two lines are needed because TunableOp uses third-party/fmt
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
endif()
if(USE_FLASH_ATTENTION)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
endif()
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
@ -464,17 +459,12 @@ if(USE_ROCM)
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
if(NOT USE_ROCM_CK_GEMM)
if(WIN32) # Windows doesn't support Composable Kernels
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_hip_bgemm} ${native_hip_ck})
endif()
if(WIN32) # Windows doesn't support Composable Kernels and Triton
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
endif()
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
list(APPEND all_hip_cpp
${native_nested_hip_cpp}
@ -709,10 +699,10 @@ endif()
if(USE_MPS)
include(../../../cmake/Metal.cmake)
set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm})
set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h})
if(CAN_COMPILE_METAL)
foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
list(APPEND AIR_BASIC ${TGT_BASIC})
@ -727,7 +717,7 @@ if(USE_MPS)
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
else()
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h")
metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME})

View File

@ -480,9 +480,6 @@ at::BlasBackend Context::blasPreferredBackend() {
// call site for blasPreferredBackend(), we set it to an actual value.
if (blas_preferred_backend == at::BlasBackend::Default) {
blas_preferred_backend = at::BlasBackend::Cublas;
// This logic sits in the getter because it needs to validate
// values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT
// which initialize the backend without calling the setter
#ifdef USE_ROCM
// AMD Instinct targets prefer hipblaslt
static const bool hipblaslt_preferred = []() {
@ -512,10 +509,6 @@ at::BlasBackend Context::blasPreferredBackend() {
// hipblaslt support for all archs is not as complete as hipblas
if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
static const bool hipblaslt_unsupported = []() {
if(!hasCuBLASLt())
{
return true;
}
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
@ -541,24 +534,6 @@ at::BlasBackend Context::blasPreferredBackend() {
return blas_preferred_backend;
}
bool Context::ckSupported() {
#ifdef USE_ROCM
static const std::vector<std::string> supported_archs = {
"gfx90a", "gfx942", "gfx950"
};
for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
TORCH_WARN_ONCE(
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
return false;
}
}
return true;
#else
return false;
#endif
}
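A rough standalone C++ sketch of the allow-list check above: CK is usable only if every visible device reports a supported gfx arch. The device lists passed in below are illustrative; the real code asks CUDAHooks for each device's architecture.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

bool all_devices_ck_supported(const std::vector<std::string>& device_archs) {
  static const std::vector<std::string> supported = {"gfx90a", "gfx942", "gfx950"};
  // Every device must be on the allow-list; a single unsupported arch disables CK.
  return std::all_of(device_archs.begin(), device_archs.end(),
                     [](const std::string& arch) {
                       return std::find(supported.begin(), supported.end(), arch) !=
                           supported.end();
                     });
}

int main() {
  std::cout << all_devices_ck_supported({"gfx942", "gfx90a"}) << '\n';   // 1
  std::cout << all_devices_ck_supported({"gfx942", "gfx1100"}) << '\n';  // 0
  return 0;
}
```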
void Context::setBlasPreferredBackend(at::BlasBackend b) {
#ifdef _MSC_VER
TORCH_WARN_ONCE(
@ -568,14 +543,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
#ifdef USE_ROCM
static const bool ckSupportedFlag = ckSupported();
static const bool hasCKGEMMFlag = hasCKGEMM();
TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag),
"Cannot set preferred blas backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK GEMM support: ", hasCKGEMMFlag);
#endif
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
@ -587,40 +556,35 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#endif
}
at::ROCmFABackend Context::getROCmFAPreferredBackend() {
#ifdef USE_ROCM
// Set potential "Default" value so we don't have to interpret at call sites.
// We use aotriton backend as the default, for now.
if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) {
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
} else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) {
// This logic sits in the getter because it needs to validate
// values set via env vars such as TORCH_ROCM_FA_PREFER_CK
// which initialize the backend without calling the setter
// Perform validity checking
static const bool hasCKSDPAFlag = hasCKSDPA();
static const bool ckSupportedFlag = ckSupported();
if(!(hasCKSDPAFlag && ckSupportedFlag)){
TORCH_WARN_ONCE(
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
}
}
#endif
at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
return rocm_fa_preferred_backend;
}
void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
// TODO: add plumbing for hasCK for validity checking
TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
"Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
#ifdef USE_ROCM
static const bool hasCKSDPAFlag = hasCKSDPA();
static const bool ckSupportedFlag = ckSupported();
TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag),
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
if(b == at::ROCmFABackend::Ck) {
static const bool ck_unsupported = []() {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942"
};
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
TORCH_WARN_ONCE(
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
return true;
}
}
return false;
}();
if(!ck_unsupported) rocm_fa_preferred_backend = b;
}
else {
rocm_fa_preferred_backend = b;
}
#endif
rocm_fa_preferred_backend = b;
}

View File

@ -132,7 +132,6 @@ class TORCH_API Context {
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool ckSupported();
static bool hasMAGMA() {
return detail::getCUDAHooks().hasMAGMA();
}
@ -163,12 +162,6 @@ class TORCH_API Context {
static bool hasROCM() {
return detail::getCUDAHooks().hasROCM();
}
static bool hasCKSDPA() {
return detail::getCUDAHooks().hasCKSDPA();
}
static bool hasCKGEMM() {
return detail::getCUDAHooks().hasCKGEMM();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
@ -259,7 +252,7 @@ class TORCH_API Context {
at::BlasBackend blasPreferredBackend();
void setBlasPreferredBackend(at::BlasBackend);
at::ROCmFABackend getROCmFAPreferredBackend();
at::ROCmFABackend getROCmFAPreferredBackend() const;
void setROCmFAPreferredBackend(at::ROCmFABackend);
// Note [Enabling Deterministic Operations]

View File

@ -31,9 +31,7 @@ c10::Allocator* GetCPUAllocatorMaybePinned(bool pin_memory) {
return at::globalContext().getPinnedMemoryAllocator(opt_device_type);
} else {
TORCH_CHECK(
false,
"pin_memory=True requires a CUDA or other accelerator backend; "
"no pinned memory allocator is available on this system.")
false, "Need to provide pin_memory allocator to use pin memory.")
}
}

View File

@ -239,7 +239,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
KERNEL_MPS(scaled_dot_product_attention, lower_precision_fp)
// fp32
KERNEL_MPS(conv_transpose3d, input, fp32)
KERNEL_MPS(acos, fp32)
KERNEL_MPS(asin, fp32)
KERNEL_MPS(cosh, fp32)

View File

@ -97,8 +97,6 @@ c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
return ComplexType::get();
case Tag::Int:
return IntType::get();
case Tag::UInt:
return IntType::get();
case Tag::SymInt:
return c10::SymIntType::get();
case Tag::SymFloat:
@ -322,8 +320,6 @@ IValue IValue::equals(const IValue& rhs) const {
return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
case Tag::Int:
return rhs.isInt() && lhs.toInt() == rhs.toInt();
case Tag::UInt:
return rhs.isUnsigned() && lhs.toUInt() == rhs.toUInt();
case Tag::SymInt:
return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
case Tag::SymFloat:
@ -383,8 +379,6 @@ size_t IValue::hash(const IValue& v) {
case Tag::Int:
return c10::get_hash(v.payload.u.as_int);
// NB: these are technically strict aliasing violations
case Tag::UInt:
return c10::get_hash(v.payload.u.as_int);
case Tag::SymInt:
return c10::get_hash(v.payload.u.as_int);
case Tag::SymFloat:
@ -812,8 +806,6 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) {
return printComplex(out, v);
} case IValue::Tag::Int:
return out << v.toInt();
case IValue::Tag::UInt:
return out << v.toUInt();
case IValue::Tag::SymInt:
return out << v.toSymInt();
case IValue::Tag::SymFloat:

View File

@ -12,7 +12,6 @@
#include <c10/macros/Export.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/intrusive_ptr.h>
#include <limits>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@ -161,7 +160,6 @@ struct Capsule {
_(Double) \
_(ComplexDouble) \
_(Int) \
_(UInt) \
_(SymInt) \
_(SymFloat) \
_(SymBool) \
@ -655,29 +653,6 @@ struct TORCH_API IValue final {
}
}
// Unsigned
IValue(uint64_t u) : tag( u <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt) {
payload.u.as_uint = u;
}
// See Note [Meaning of HAS_u]
// IValue type model closely follows that of c10::Scalar
// Where all integers are upcast to 64-bit representation, and `as_int` is used as default
// representation unless value could not be represented as signed int
bool isUnsigned() const {
return Tag::UInt == tag || (Tag::Int == tag && payload.u.as_int >= 0);
}
uint64_t toUInt() const {
if (isUnsigned()) {
return payload.u.as_uint;
} else {
TORCH_INTERNAL_ASSERT(0, "expected unsigned int");
}
}
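A small standalone C++ sketch of the tag-selection rule described in the note above; the `Tag` enum and `tag_for` helper are hypothetical stand-ins, not the real IValue types.

```cpp
// An unsigned 64-bit value keeps the signed Int tag whenever it fits in
// int64_t; otherwise it needs the UInt tag.
#include <cassert>
#include <cstdint>
#include <limits>

enum class Tag { Int, UInt };

Tag tag_for(uint64_t u) {
  return u <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())
      ? Tag::Int    // representable as signed, stored in the as_int payload
      : Tag::UInt;  // too large for int64_t, needs the unsigned payload
}

int main() {
  assert(tag_for(42u) == Tag::Int);
  assert(tag_for(0x8000000000000000ull) == Tag::UInt);  // > INT64_MAX
  return 0;
}
```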
// Bool
IValue(bool b) : tag(Tag::Bool) {
#if defined(__clang__) && defined(__x86_64__)
@ -918,14 +893,8 @@ struct TORCH_API IValue final {
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
s.isIntegral(false), "Unknown type in Scalar");
if (s.isUnsigned()) {
const auto val = s.toUInt64();
payload.u.as_uint = val;
tag = val <= std::numeric_limits<int64_t>::max() ? Tag::Int : Tag::UInt;
} else {
payload.u.as_int = s.toLong();
tag = Tag::Int;
}
tag = Tag::Int;
payload.u.as_int = s.toLong();
}
}
@ -949,8 +918,6 @@ struct TORCH_API IValue final {
return toSymFloat();
else if (isSymBool())
return toSymBool();
else if (isUnsigned())
return toUInt();
TORCH_CHECK(false, "IValue is not a Scalar");
}
@ -1280,8 +1247,6 @@ struct TORCH_API IValue final {
return true;
case Tag::Int:
return false;
case Tag::UInt:
return false;
case Tag::SymInt:
return true;
case Tag::SymFloat:
@ -1378,8 +1343,6 @@ struct TORCH_API IValue final {
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
int64_t as_int;
// See Note [Meaning of HAS_u]
uint64_t as_uint;
double as_double;
bool as_bool;
// Invariant: never nullptr; null state is represented as

View File

@ -832,7 +832,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
#if defined(USE_ROCM) && !defined(_MSC_VER)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
@ -1273,7 +1273,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
#if defined(USE_ROCM) && !defined(_MSC_VER)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
@ -1289,7 +1289,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
#if defined(USE_ROCM) && !defined(_MSC_VER)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
@ -1341,7 +1341,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
#if defined(USE_ROCM) && !defined(_MSC_VER)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@ -1357,7 +1357,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
#if defined(USE_ROCM) && !defined(_MSC_VER)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}

View File

@ -207,27 +207,6 @@ bool CUDAHooks::hasCuBLASLt() const {
#endif
}
bool CUDAHooks::hasCKSDPA() const {
#if !defined(USE_ROCM)
return false;
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA)
return true;
#else
return false;
#endif
}
bool CUDAHooks::hasCKGEMM() const {
#if !defined(USE_ROCM)
return false;
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
return true;
#else
return false;
#endif
}
bool CUDAHooks::hasROCM() const {
// Currently, this is same as `compiledWithMIOpen`.
// But in future if there are ROCm builds without MIOpen,

View File

@ -31,8 +31,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuSOLVER() const override;
bool hasCuBLASLt() const override;
bool hasROCM() const override;
bool hasCKSDPA() const override;
bool hasCKGEMM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool isBuilt() const override {return true;}

View File

@ -118,14 +118,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
return false;
}
virtual bool hasCKSDPA() const {
return false;
}
virtual bool hasCKGEMM() const {
return false;
}
virtual const at::cuda::NVRTC& nvrtc() const {
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}

View File

@ -21,10 +21,6 @@ bool isMTIAHooksBuilt() {
} // namespace detail
bool MTIAHooksInterface::isAvailable() const {
return detail::isMTIAHooksBuilt() && detail::getMTIAHooks().deviceCount() > 0;
}
C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)
} // namespace at

View File

@ -149,8 +149,6 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
FAIL_MTIAHOOKS_FUNC(__func__);
return;
}
virtual bool isAvailable() const override;
};
struct TORCH_API MTIAHooksArgs {};

View File

@ -43,6 +43,7 @@ TensorBase empty_mps(
int64_t nelements = c10::multiply_integers(size);
auto dtype = dtype_or_default(dtype_opt);
TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
TORCH_CHECK_TYPE(dtype != ScalarType::BFloat16 || is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_14_0_PLUS), "MPS BFloat16 is only supported on MacOS 14 or newer");
auto dtype_meta = scalarTypeToTypeMeta(dtype);

View File

@ -18,7 +18,11 @@ namespace at::mps {
// Helper enum to check if a MPSGraph op is supported in a given macOS version
enum class MacOSVersion : uint32_t {
MACOS_VER_14_4_PLUS = 0,
MACOS_VER_13_1_PLUS = 0,
MACOS_VER_13_2_PLUS,
MACOS_VER_13_3_PLUS,
MACOS_VER_14_0_PLUS,
MACOS_VER_14_4_PLUS,
MACOS_VER_15_0_PLUS,
MACOS_VER_15_1_PLUS,
MACOS_VER_15_2_PLUS,

View File

@ -32,11 +32,11 @@ MPSDevice::~MPSDevice() {
MPSDevice::MPSDevice() : _mtl_device(nil) {
// Check that MacOS 13.0+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 14.0
// Create the MPSGraph and check method introduced in 13.0
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");
if ([mpsCD instancesRespondToSelector:@selector(HermiteanToRealFFTWithTensor:axes:descriptor:name:)] == NO) {
if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
return;
}
@ -66,12 +66,24 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
}
};
static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
static bool _macos_14_0_plus = is_os_version_at_least(14, 0);
static bool _macos_14_4_plus = is_os_version_at_least(14, 4);
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);
static bool _macos_15_1_plus = is_os_version_at_least(15, 1);
static bool _macos_15_2_plus = is_os_version_at_least(15, 2);
switch (version) {
case MacOSVersion::MACOS_VER_13_1_PLUS:
return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS:
return _macos_13_2_plus;
case MacOSVersion::MACOS_VER_13_3_PLUS:
return _macos_13_3_plus;
case MacOSVersion::MACOS_VER_14_0_PLUS:
return _macos_14_0_plus;
case MacOSVersion::MACOS_VER_14_4_PLUS:
return _macos_14_4_plus;
case MacOSVersion::MACOS_VER_15_0_PLUS:

View File

@ -34,7 +34,7 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
case 14:
switch (minor) {
case 0:
return true;
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
case 4:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
default:
@ -42,7 +42,19 @@ bool MPSHooks::isOnMacOSorNewer(unsigned major, unsigned minor) const {
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS);
}
case 13:
return true;
switch (minor) {
case 0:
return true;
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
case 3:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
default:
TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
}
default:
TORCH_WARN("Checking for unexpected MacOS ", major, ".", minor, " returning false");
return false;

View File

@ -51,7 +51,7 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
// brgemm_pack_B is changed to transform and the setting of brgemm beta is changed to set_add_C
#if (IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR == 5)
#define ONEDNN_UKERNEL_1
#elif ((IDEEP_VERSION_MAJOR == 3 && IDEEP_VERSION_MINOR >= 6) || (IDEEP_VERSION_MAJOR > 3))
#elif (IDEEP_VERSION_MAJOR >= 3 && IDEEP_VERSION_MINOR >= 6)
#define ONEDNN_UKERNEL_2
#endif
#if ((defined(ONEDNN_UKERNEL_1) || defined(ONEDNN_UKERNEL_2)) && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))))

View File

@ -206,16 +206,6 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
// B Base pointer to a tensor B.
// C Pointer to a tensor C (accumulation buffer).
// Note only batch size 1 is used currently
// Define macros for available brgemm APIs
// so that callers can determine which APIs are available
#define CPUBLAS_BRGEMM_F16F16F32 // half * half -> float
#define CPUBLAS_BRGEMM_BF16BF16F32 // bfloat16 * bfloat16 -> float
#define CPUBLAS_BRGEMM_F32F32F32 // float * float -> float
#define CPUBLAS_BRGEMM_U8U8I32 // unsigned char * unsigned char -> int32
#define CPUBLAS_BRGEMM_U8I8I32 // unsigned char * signed char -> int32
#define CPUBLAS_BRGEMM_I8I8I32 // signed char * signed char -> int32
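A hypothetical caller-side sketch of how such availability macros are meant to be used: feature-test the macro before choosing a dtype path. The function names are stubs for illustration, and when compiled standalone the macro is absent, so the fallback branch runs.

```cpp
#include <cstdio>

static void run_f16_brgemm_path() { std::puts("using half*half->float brgemm"); }
static void run_reference_path()  { std::puts("using reference gemm path"); }

int main() {
#if defined(CPUBLAS_BRGEMM_F16F16F32)
  run_f16_brgemm_path();  // the build exposes the half-precision brgemm kernel
#else
  run_reference_path();   // macro not defined: fall back to a reference path
#endif
  return 0;
}
```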
TORCH_API void brgemm(
int64_t M,
int64_t N,

View File

@ -3,7 +3,6 @@
#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/ConvolutionMM3d.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/Pool.h>
@ -464,7 +463,7 @@ struct ConvParams {
return true;
}
// native kernel doesn't support 64-bit non-splittable case
if (cudnn_enabled && !(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
if (cudnn_enabled && needs_64bit_indexing_no_split(input, weight)) {
static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
if (!(cudnn_version >= 90300 && at::native::cudnnv8_enabled_check_debug())) {
TORCH_WARN_ONCE("cuDNN cannot be used for large non-batch-splittable convolutions"

View File

@ -282,14 +282,6 @@ __device__ __forceinline__ void opportunistic_fastAtomicAdd(
}
// not coalesced, so now let's try to capture lane-matches...
if (numel > 16 /*<-heuristic threshold*/ * 64 ) {
// well shucks, unlikely to capture same-dest atomics in a wave.
// fall back to direct fastAtomic...
fastAtomicAdd(self_ptr, index, numel, value, true);
return;
}
// __activemask() -- finds the set of threads in the warp that are about to perform atomicAdd
// __match_any_sync() -- returns bit mask of the threads that have same dest addr
auto mask = __match_any_sync(__activemask(), (int64_t)dst);

File diff suppressed because it is too large.

View File

@ -70,31 +70,4 @@ void run_cudnn_SDP_bprop(
const Tensor& dropoutseed,
const Tensor& dropoutoffset);
void run_cudnn_SDP_bprop_nestedtensor(
int64_t b,
int64_t h_q,
int64_t h_k,
int64_t h_v,
int64_t s_q,
int64_t s_kv,
int64_t d_qk,
int64_t d_v,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& cum_seqlen_q,
const Tensor& cum_seqlen_kv,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const std::optional<Tensor>& attn_bias,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset);
} // namespace at::native

View File

@ -10,7 +10,6 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
@ -19,7 +18,7 @@ template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif
} // namespace at::native

View File

@ -1,7 +1,6 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm.h>
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -782,4 +781,3 @@ void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -1,7 +1,6 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -485,4 +484,3 @@ void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -1,7 +1,6 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -607,4 +606,3 @@ void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -88,8 +88,14 @@ std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
@ -429,6 +435,14 @@ inline T* LookUpOrCreateCachedGraph(const std::string& key, std::function<void(M
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE( \
"MPS: no support for int64 for ", \
op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}
/**
* Returns distance from lowest to highest element offset in given tensor.
*/
@ -604,6 +618,10 @@ inline void runMPSGraph(MPSStream* stream, MPSGraph* graph, NSDictionary* feeds,
runMPSGraph(stream, graph, feeds, dictionaryFromPlaceholders(result));
}
inline bool supportsComplex() {
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS);
}
// MPS yet to support double types, but starting from MacOS 14, supports bfloat16
inline bool supportedFloatingType(ScalarType dtype) {
return dtype == kFloat || dtype == kHalf || dtype == kBFloat16;
@ -615,7 +633,7 @@ inline bool supportedFloatingType(const TensorBase& t) {
inline bool supportedFloatingOrComplexType(ScalarType dtype) {
if (dtype == kComplexFloat || dtype == kComplexHalf) {
return true;
return supportsComplex();
}
return supportedFloatingType(dtype);
}
@ -623,6 +641,11 @@ inline bool supportedFloatingOrComplexType(const TensorBase& t) {
return supportedFloatingOrComplexType(t.scalar_type());
}
inline void checkSupportsBFloat16() {
TORCH_CHECK_TYPE(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS),
"MPS bfloat16 type is supported on MacOS 14.0 or newer.");
}
inline bool needsGather(const TensorBase& t) {
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset());

View File

@ -89,6 +89,10 @@ void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds,
mpsStream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT_ADAPTIVE);
}
static inline void checkSupportsComplex() {
TORCH_CHECK_TYPE(supportsComplex(), "MPS complex types are only supported on MacOS 14.0 or newer.");
}
MPSDataType getMPSDataType(ScalarType scalar_type) {
switch (scalar_type) {
case ScalarType::Float:
@ -96,6 +100,7 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
case ScalarType::Half:
return MPSDataTypeFloat16;
case ScalarType::BFloat16:
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
case ScalarType::Int:
return MPSDataTypeInt32;
@ -114,8 +119,10 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
"Please use float32 instead.")
case ScalarType::ComplexHalf:
checkSupportsComplex();
return MPSDataTypeComplexFloat16;
case ScalarType::ComplexFloat:
checkSupportsComplex();
return MPSDataTypeComplexFloat32;
// Unsigned types
case ScalarType::UInt64:
@ -133,10 +140,16 @@ MPSDataType getMPSDataType(ScalarType scalar_type) {
// #issue 104398441 sortWithTensor and argsortWithTensor have support for
// Int32, Half and Float32 types. These utilities are to help cast to these
// types.
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
(dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
if (condition) {
dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
return [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
@ -147,10 +160,16 @@ MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor,
// #issue 104398441 sortWithTensor and argsortWithTensor have support for
// Int32, Half and Float32 types. These utilities are to help cast from these
// types.
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input) {
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) &&
(dataType != MPSDataTypeFloat16) && (dataType != MPSDataTypeInt64);
bool condition =
(dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
if (includesInt64) {
condition = condition && (dataType != MPSDataTypeInt64);
}
if (condition) {
inputTensor = [mpsGraph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
@ -167,6 +186,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
case ScalarType::Half:
return MPSDataTypeFloat16;
case ScalarType::BFloat16:
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
case ScalarType::Int:
return MPSDataTypeInt32;
@ -181,11 +201,13 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
case ScalarType::Bool:
return MPSDataTypeBool;
case ScalarType::ComplexHalf:
checkSupportsComplex();
return MPSDataTypeComplexFloat16;
// This is an intentional fallthrough supporting ComplexDouble for Scalar
// types as they are casted to Complex64 currently.
case ScalarType::ComplexDouble:
case ScalarType::ComplexFloat:
checkSupportsComplex();
return MPSDataTypeComplexFloat32;
// Unsigned types
case ScalarType::UInt64:
@ -245,6 +267,7 @@ std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type) {
case ScalarType::Half:
return "half";
case ScalarType::BFloat16:
checkSupportsBFloat16();
return "bfloat";
case ScalarType::Int:
return "int";
@ -856,7 +879,9 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
MTLCompileOptions* options = compile_options;
if (!options) {
options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion3_1];
    // Need 3.0 for atomic operations, 3.1 introduces bfloat support
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion3_0];
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) {
options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe;
options.mathFloatingPointFunctions =

View File

@ -5,6 +5,29 @@
using namespace metal;
using namespace c10::metal;
namespace c10 {
namespace metal {
// There is no atomic 64-bit add in Metal yet, but this implements a consistent
// add, i.e. if multiple threads modify the same 64-bit value, the result stored
// at the address will eventually be equal to its original value plus the sum of
// all operands
template <>
struct AtomicType<long> {
using type = ::metal::atomic<uint>;
static inline void atomic_add(device type* data, long offset, long value) {
const auto value_bits = as_type<ulong>(value);
const uint low = static_cast<uint>(value_bits);
uint high = static_cast<uint>(value_bits >> 32);
auto ptr = data + (offset << 1);
auto old_low = atomic_fetch_add_explicit(ptr, low, memory_order_relaxed);
high += (old_low + low < old_low) ? 1 : 0;
atomic_fetch_add_explicit(ptr + 1, high, memory_order_relaxed);
}
};
} // namespace metal
} // namespace c10
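A minimal host-side C++ sketch of the same carry-propagating trick, assuming two 32-bit atomics laid out as the low and high halves of one logical 64-bit counter; names are illustrative and this is not the Metal code above.

```cpp
// Emulate a 64-bit add with two relaxed 32-bit atomic adds, propagating the
// carry from the low word into the high word.
#include <atomic>
#include <cstdint>
#include <cstdio>

void atomic_add_u64_as_two_u32(std::atomic<uint32_t>* lo_hi, int64_t value) {
  const uint64_t bits = static_cast<uint64_t>(value);
  const uint32_t low = static_cast<uint32_t>(bits);
  uint32_t high = static_cast<uint32_t>(bits >> 32);
  const uint32_t old_low = lo_hi[0].fetch_add(low, std::memory_order_relaxed);
  // Carry into the high word iff the 32-bit addition of the low halves wrapped.
  high += (static_cast<uint32_t>(old_low + low) < old_low) ? 1 : 0;
  lo_hi[1].fetch_add(high, std::memory_order_relaxed);
}

int main() {
  std::atomic<uint32_t> counter[2] = {{0xFFFFFFF0u}, {0u}};  // low word, high word
  atomic_add_u64_as_two_u32(counter, 0x20);                  // wraps the low word
  const uint64_t result =
      (static_cast<uint64_t>(counter[1].load()) << 32) | counter[0].load();
  std::printf("0x%llx\n", static_cast<unsigned long long>(result));  // 0x100000010
  return 0;
}
```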
struct IndexAB {
constant int64_t* indexArray;
};
@ -211,15 +234,13 @@ REGISTER_INDEX_OP_ALL_DTYPES(put_serial);
REGISTER_INDEX_OP(put_accumulate, float, float);
REGISTER_INDEX_OP(put_accumulate, half, half);
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
REGISTER_INDEX_OP(put_accumulate, long, long);
REGISTER_INDEX_OP(put_accumulate, int, int);
REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
REGISTER_INDEX_OP(put_accumulate, float2, float2);
REGISTER_INDEX_OP(put_accumulate, half2, half2);
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(

View File

@ -68,37 +68,6 @@ kernel void matmul(
}
}
template <typename T>
kernel void addmm(
constant T* mat1Data [[buffer(0)]],
constant T* mat2Data [[buffer(1)]],
device T* outputData [[buffer(2)]],
constant T* biasData [[buffer(3)]],
constant array<c10::metal::opmath_t<T>, 2>& alpha_beta [[buffer(4)]],
constant array<ulong2, 4>& strides [[buffer(5)]],
constant uint3& sizes [[buffer(6)]],
uint2 tid [[thread_position_in_threadgroup]],
uint2 thread_id [[thread_position_in_grid]]) {
threadgroup T A_tile[TILE_DIM][TILE_DIM];
threadgroup T B_tile[TILE_DIM][TILE_DIM];
auto sum = matmul_inner<T>(
mat1Data,
mat2Data,
reinterpret_cast<constant array<ulong2, 3>&>(strides),
sizes,
A_tile,
B_tile,
tid,
thread_id);
if (thread_id.y < sizes.x && thread_id.x < sizes.z) {
auto bias =
biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y];
outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] =
static_cast<T>(alpha_beta[0] * sum + alpha_beta[1] * bias);
}
}
template <typename T>
kernel void naive_bmm(
constant T* mat1Data [[buffer(0)]],
@ -644,15 +613,17 @@ kernel void applyPivots(
}
}
#define INSTANTIATE_MM_OPS(DTYPE) \
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
device DTYPE * outputData [[buffer(2)]], \
constant array<ulong2, 3> & strides [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 tid [[thread_position_in_threadgroup]], \
uint2 group_id [[threadgroup_position_in_grid]]); \
#define INSTANTIATE_NAIVE_MM(DTYPE) \
template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
device DTYPE * outputData [[buffer(2)]], \
constant array<ulong2, 3> & strides [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 tid [[thread_position_in_threadgroup]], \
uint2 group_id [[threadgroup_position_in_grid]])
#define INSTANTIATE_NAIVE_BMM(DTYPE) \
template [[host_name("naive_bmm_" #DTYPE)]] kernel void naive_bmm<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
@ -660,26 +631,20 @@ kernel void applyPivots(
constant array<ulong, 9> & strides [[buffer(3)]], \
constant uint4 & sizes [[buffer(4)]], \
uint3 tid [[thread_position_in_threadgroup]], \
uint3 group_id [[threadgroup_position_in_grid]]); \
template [[host_name("addmm_" #DTYPE)]] kernel void addmm<DTYPE>( \
constant DTYPE * mat1Data [[buffer(0)]], \
constant DTYPE * mat2Data [[buffer(1)]], \
device DTYPE * outputData [[buffer(2)]], \
constant DTYPE * biasData [[buffer(3)]], \
constant array<c10::metal::opmath_t<DTYPE>, 2> & \
alpha_beta [[buffer(4)]], \
constant array<ulong2, 4> & strides [[buffer(5)]], \
constant uint3 & sizes [[buffer(6)]], \
uint2 tid [[thread_position_in_threadgroup]], \
uint2 group_id [[threadgroup_position_in_grid]])
uint3 group_id [[threadgroup_position_in_grid]])
INSTANTIATE_MM_OPS(float);
INSTANTIATE_MM_OPS(half);
INSTANTIATE_MM_OPS(bfloat);
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
INSTANTIATE_NAIVE_MM(bfloat);
// Integral MM
INSTANTIATE_MM_OPS(long);
INSTANTIATE_MM_OPS(int);
INSTANTIATE_MM_OPS(short);
INSTANTIATE_MM_OPS(char);
INSTANTIATE_MM_OPS(uchar);
INSTANTIATE_NAIVE_MM(short);
INSTANTIATE_NAIVE_MM(int);
INSTANTIATE_NAIVE_MM(long);
INSTANTIATE_NAIVE_MM(char);
INSTANTIATE_NAIVE_MM(uchar);
INSTANTIATE_NAIVE_BMM(short);
INSTANTIATE_NAIVE_BMM(int);
INSTANTIATE_NAIVE_BMM(long);
INSTANTIATE_NAIVE_BMM(char);
INSTANTIATE_NAIVE_BMM(uchar);

View File

@ -88,53 +88,6 @@ void max_pool_3d_input_iter(
}
}
template <typename T, bool return_indices>
void max_pool_2d_input_iter(
constant T* input,
device T* output,
device int64_t* indices,
constant int32_t* input_sizes,
constant int32_t* input_strides,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
constant int32_t* dilation) {
auto bounds0 = get_input_iter_bounds<0>(
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
auto bounds1 = get_input_iter_bounds<1>(
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
auto d0 = dilation[0];
auto d1 = dilation[1];
T max_value = input
[input_strides[0] * bounds0.start + input_strides[1] * bounds1.start];
auto max_index = bounds0.start * input_sizes[1] + bounds1.start;
for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) {
auto offset0 = input_strides[0] * i0;
for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) {
auto offset1 = input_strides[1] * i1;
auto input_value = input[offset0 + offset1];
bool is_greater = input_value > max_value;
max_value = is_greater ? input_value : max_value;
if (return_indices) {
auto input_index = i0 * input_sizes[1] + i1;
max_index = is_greater ? input_index : max_index;
}
}
}
*output = max_value;
if (return_indices) {
*indices = max_index;
}
}
struct PoolOffsets {
int32_t output;
int32_t indices;
@ -259,7 +212,7 @@ kernel void max_pool(
PoolOffsets offsets = find_pool_offsets(
output_sizes,
output_strides,
return_indices ? indices_strides : nullptr,
indices_strides,
input_strides,
pooling_dim_indices,
dims,
@ -271,47 +224,18 @@ kernel void max_pool(
indices += offsets.indices;
input += offsets.input_leading;
switch (pooling_dims) {
case 2:
if (return_indices) {
return max_pool_2d_input_iter<T, /*return_indices=*/true>(
input,
output,
indices,
input_sizes + leading_dims,
input_strides + leading_dims,
pooling_dim_indices,
kernel_size,
stride,
padding,
dilation);
} else {
return max_pool_2d_input_iter<T, /*return_indices=*/false>(
input,
output,
indices,
input_sizes + leading_dims,
input_strides + leading_dims,
pooling_dim_indices,
kernel_size,
stride,
padding,
dilation);
}
case 3:
return max_pool_3d_input_iter<T>(
input,
output,
indices,
input_sizes + leading_dims,
input_strides + leading_dims,
pooling_dim_indices,
kernel_size,
stride,
padding,
dilation,
return_indices);
}
max_pool_3d_input_iter<T>(
input,
output,
indices,
input_sizes + leading_dims,
input_strides + leading_dims,
pooling_dim_indices,
kernel_size,
stride,
padding,
dilation,
return_indices);
}
// Finds the element in the grad input which corresponds to the index into the

View File

@ -53,7 +53,6 @@ void binary_op_kernel(const std::string func_name,
.add_input(input)
.add_input(other)
.check_all_same_dtype(false)
.promote_inputs_to_common_dtype(true)
.build();
lib.exec_binary_kernel(iter, func_name, alpha);

View File

@ -48,11 +48,28 @@ typedef MPSGraphTensor* (^BinaryOpBlock)(BinaryOpCachedGraph*, MPSGraphTensor*,
#define BinaryOpFn(graph, primary, secondary) \
MPSGraphTensor*(mps::BinaryOpCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary)
static inline Tensor legacy_complex_as_view(const Tensor& t) {
// Convert non-complex types (and cdouble CPU scalars) to cfloat
if (!isComplexType(t.scalar_type()) || t.scalar_type() == kComplexDouble) {
return at::view_as_real(t.to(kMPS, kComplexFloat));
}
return at::view_as_real(t.dim() != 0 ? t : t.to(kMPS));
}
static void binaryOpTensor(const Tensor& self,
const Tensor& other,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long &&
(self.scalar_type() != ScalarType::Half && self.scalar_type() != ScalarType::Float)))),
"MPS: ",
op_name,
" op with int64 input is supported natively starting from macOS 13.2");
TORCH_CHECK_TYPE(!isComplexType(self.scalar_type()) || mps::supportsComplex(),
"Complex types are supported starting from MacOS 14.0+");
MPSStream* mpsStream = getCurrentMPSStream();
const bool is_self_scalar = self.dim() == 0;

View File

@ -51,6 +51,9 @@ inline void dot_check(const Tensor& self, const Tensor& other) {
} // namespace mps
Tensor dot_mps(const Tensor& self, const Tensor& other) {
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || self.scalar_type() != ScalarType::Long,
"MPS: dot op doesn't support int64 input on MacOS13")
using namespace mps;
using CachedGraph = MPSBinaryCachedGraph;

View File

@ -124,6 +124,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
IntArrayRef dilation,
int64_t groups,
std::optional<IntArrayRef> input_shape) {
const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
Tensor input_t = input_t_;
bool is3DConv = input_t.dim() == 5;
@ -131,6 +132,9 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
input_t = input_t.contiguous();
}
TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
"Conv3D is only supported on MPS for MacOS_13_2 or newer");
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
using namespace at::native::mps;

View File

@ -60,6 +60,7 @@ static void copy_cast_mps(at::Tensor& dst,
outputTensor = [mpsGraph castTensor:outputTensor toType:dstDType name:@"cast"];
}
if (needs_conj) {
TORCH_CHECK(supportsComplex(), "MPS complex tensors conjugation needs MacOS14+");
outputTensor = [mpsGraph conjugateWithTensor:outputTensor name:nil];
}
@ -274,7 +275,24 @@ static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, boo
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
stream->copy(sourceBuffer, destBuffer, src.nbytes(), src_byte_offset, dst_byte_offset, profile_id);
} else {
if (dst_byte_offset) {
// Simulate cast to Complex on older MacOS by initializing real and imag parts
if (dst_.is_complex() && !supportsComplex()) {
if (!src.is_complex()) {
at::real(dst_).copy_(src);
at::imag(dst_).fill_(0);
} else if (src.is_conj() || dst_.is_conj()) {
// One cannot take a view of a conjugated tensor, but real and imag views are fine;
// use them to implement the conjugation
at::real(dst_).copy_(at::real(src));
if (src.is_conj() != dst_.is_conj()) {
at::imag(dst_).copy_(at::neg(at::imag(src)));
} else {
at::imag(dst_).copy_(at::imag(src));
}
} else {
at::view_as_real(dst_).copy_(at::view_as_real(src));
}
} else if (dst_byte_offset) {
auto maybeCastedSource =
at::empty(dst_.sizes(), dst_.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
auto maybeCastedSourceBuffer = getMTLBufferStorage(maybeCastedSource);
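As a plain-C++ illustration of what this fallback does (a sketch under the assumption that complex values are stored as interleaved real/imag float pairs; none of these helper names come from the MPS code): a "cast to complex" copy fills the imaginary half with zeros, and a conjugating copy negates it.

#include <iostream>
#include <vector>

// Sketch: complex tensors viewed as interleaved (real, imag) float pairs,
// mimicking view_as_real / at::real / at::imag on a flat buffer.
void copy_real_to_complex(const std::vector<float>& src, std::vector<float>& dst) {
  // dst has 2 * src.size() floats: real parts copied, imaginary parts zeroed.
  for (size_t i = 0; i < src.size(); ++i) {
    dst[2 * i] = src[i];
    dst[2 * i + 1] = 0.0f;
  }
}

void copy_complex_conjugated(const std::vector<float>& src, std::vector<float>& dst) {
  // Conjugation via the real/imag "views": copy real, negate imag.
  for (size_t i = 0; i + 1 < src.size(); i += 2) {
    dst[i] = src[i];
    dst[i + 1] = -src[i + 1];
  }
}

int main() {
  std::vector<float> real_src = {1.0f, 2.0f};
  std::vector<float> cplx(4);
  copy_real_to_complex(real_src, cplx);               // (1+0i), (2+0i)

  std::vector<float> src = {1.0f, 2.0f, 3.0f, -4.0f}; // (1+2i), (3-4i)
  std::vector<float> conj(4);
  copy_complex_conjugated(src, conj);                 // (1-2i), (3+4i)
  std::cout << conj[0] << " " << conj[1] << " " << conj[2] << " " << conj[3] << "\n"; // 1 -2 3 4
  return 0;
}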

View File

@ -87,6 +87,7 @@ Tensor& random_mps_impl(Tensor& self,
case kFloat:
return MPSDataTypeFloat32;
case kBFloat16: {
checkSupportsBFloat16();
return MPSDataTypeBFloat16;
}
default:

View File

@ -88,6 +88,7 @@ using namespace mps;
// TODO: Investigate numerical discrepancies see https://github.com/pytorch/pytorch/issues/120237
Tensor& _fft_r2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self, out}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(onesided);
@autoreleasepool {
@ -128,6 +129,7 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
int64_t normalization,
int64_t last_dim_size,
Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(last_dim_size);
@autoreleasepool {
@ -153,6 +155,7 @@ Tensor& _fft_c2r_mps_out(const Tensor& self,
}
Tensor& _fft_c2c_mps_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) {
TORCH_CHECK(supportsComplex(), "FFT operations are only supported on MacOS 14+");
auto key = __func__ + getTensorsStringKey({self}) + ":" + getArrayRefString(dim) + ":" +
std::to_string(normalization) + ":" + std::to_string(forward);
@autoreleasepool {

View File

@ -127,6 +127,15 @@ Tensor grid_sampler_2d_mps(const Tensor& input,
int64_t interpolation_mode,
int64_t padding_mode,
bool align_corners) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS)) {
TORCH_WARN_ONCE("MPS: grid_sampler_2d op is supported natively starting from macOS 13.2. ",
"Falling back on CPU. This may have performance implications.");
return at::grid_sampler_2d(input.to("cpu"), grid.to("cpu"), interpolation_mode, padding_mode, align_corners)
.clone()
.to("mps");
}
auto in_size = input.sizes();
auto grid_size = grid.sizes();
auto output = at::empty({in_size[0], in_size[1], grid_size[1], grid_size[2]}, input.options());

View File

@ -108,12 +108,26 @@ static std::string getBitSizeString(const TensorBase& t) {
static void validateInputData(const TensorIteratorBase& iter,
IntArrayRef index_size,
IntArrayRef index_stride,
const std::string& op) {
const std::string& op,
bool accumulate) {
using namespace mps;
const auto num_indices = index_size.size();
TORCH_CHECK(num_indices <= 16, "Current limit allows up to 16 indices to be used in MPS indexing kernels");
AT_ASSERT(num_indices == index_stride.size());
AT_ASSERT(static_cast<int>(num_indices) == iter.ntensors() - 2);
const Tensor& inputTensor = iter.tensor(1);
const auto scalar_type = inputTensor.scalar_type();
if (accumulate) {
// No atomic support for the complex dtypes
TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type));
} else {
TORCH_CHECK(c10::isIntegralType(scalar_type, /*includesBool=*/true) || supportedFloatingType(scalar_type) ||
scalar_type == ScalarType::ComplexFloat || scalar_type == ScalarType::ComplexHalf,
getMPSTypeString(inputTensor) + std::string(" not supported for index.Tensor_out"));
}
}
static Tensor& masked_select_out_mps_impl(Tensor& result, const Tensor& self, const Tensor& mask) {
@ -144,7 +158,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
IntArrayRef index_stride,
const std::string& kernel_name,
const bool serial = false) {
validateInputData(iter, index_size, index_stride, "index.Tensor_out");
validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
if (iter.numel() == 0)
return;
if (!iter.can_use_32bit_indexing()) {
@ -186,7 +200,7 @@ static void dispatch_index_kernel(TensorIteratorBase& iter,
}
static void index_kernel_mps(TensorIteratorBase& iter, IntArrayRef index_size, IntArrayRef index_stride) {
validateInputData(iter, index_size, index_stride, "index.Tensor_out");
validateInputData(iter, index_size, index_stride, "index.Tensor_out", /*accumulate=*/false);
dispatch_index_kernel(
iter, index_size, index_stride, fmt::format("index_select_{}", getBitSizeString(iter.tensor_base(0))));
}
@ -196,7 +210,7 @@ static void index_put_kernel_mps(TensorIterator& iter,
IntArrayRef index_stride,
bool accumulate) {
@autoreleasepool {
validateInputData(iter, index_size, index_stride, "index_put_impl");
validateInputData(iter, index_size, index_stride, "index_put_impl", accumulate);
if (accumulate) {
dispatch_index_kernel(iter,
index_size,
@ -339,7 +353,14 @@ static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) {
}
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
if (self.is_complex()) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
"Falling back on CPU. This may have performance implications.");
Tensor out_fallback = nonzero_fallback(self);
at::native::resize_output(out_, out_fallback.sizes());
out_.copy_(out_fallback);
return out_;
} else if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
"Falling back on CPU. This may have performance implications.");
Tensor out_fallback = nonzero_fallback(self);
@ -424,7 +445,11 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
}
Tensor nonzero_mps(const Tensor& self) {
if (self.is_complex()) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 14.0. ",
"Falling back on CPU. This may have performance implications.");
return nonzero_fallback(self);
} else if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
"Falling back on CPU. This may have performance implications.");
return nonzero_fallback(self);

View File

@ -112,61 +112,6 @@ Tensor& do_metal_bmm(const Tensor& batch1, const Tensor& batch2, Tensor& output)
return output;
}
Tensor& do_metal_addmm(const Tensor& self,
const Tensor& other,
Tensor& output,
const Scalar& alpha,
const Scalar& beta,
const Tensor& bias) {
if (beta.toDouble() == 0 && alpha.toDouble() == 1) {
return do_metal_mm(self, other, output);
}
auto stream = getCurrentMPSStream();
auto device = MPSDevice::getInstance()->device();
auto matmulPSO = lib.getPipelineStateForFunc("addmm_" + mps::scalarToMetalTypeString(output));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
getMPSProfiler().beginProfileKernel(matmulPSO, "addmm", {self, other});
auto computeEncoder = stream->commandEncoder();
[computeEncoder setComputePipelineState:matmulPSO];
std::array<uint32_t, 3> sizes = {static_cast<uint32_t>(self.size(0)),
static_cast<uint32_t>(self.size(1)),
static_cast<uint32_t>(output.size(1))};
std::array<int64_t, 8> strides = {self.stride(0),
self.stride(1),
other.stride(0),
other.stride(1),
output.stride(0),
output.stride(1),
bias.stride(0),
bias.stride(1)};
union {
std::array<int64_t, 2> i64;
std::array<int32_t, 2> i32;
std::array<float, 2> f32;
} alpha_beta;
if (output.scalar_type() == kLong) {
alpha_beta.i64 = {alpha.toLong(), beta.toLong()};
} else if (c10::isIntegralType(output.scalar_type(), true)) {
alpha_beta.i32 = {alpha.toInt(), beta.toInt()};
} else {
TORCH_INTERNAL_ASSERT(c10::isFloatingType(output.scalar_type()));
alpha_beta.f32 = {alpha.toFloat(), beta.toFloat()};
}
constexpr uint32_t TILE_DIM = 16; // best performance in tests on multiple Macs
uint32_t gridSizeX = (output.size(1) + TILE_DIM - 1) / TILE_DIM;
uint32_t gridSizeY = (self.size(0) + TILE_DIM - 1) / TILE_DIM;
MTLSize threadsPerThreadgroup = MTLSizeMake(TILE_DIM, TILE_DIM, 1);
MTLSize threadgroupsPerGrid = MTLSizeMake(gridSizeX, gridSizeY, 1);
mtl_setArgs(computeEncoder, self, other, output, bias, alpha_beta.i64, strides, sizes);
[computeEncoder dispatchThreadgroups:threadgroupsPerGrid threadsPerThreadgroup:threadsPerThreadgroup];
getMPSProfiler().endProfileKernel(matmulPSO);
}
});
return output;
}
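For reference, do_metal_addmm computes output = beta * bias + alpha * (self @ other), launched as 16x16 tiles with grid sizes obtained by ceiling division ((n + TILE_DIM - 1) / TILE_DIM). A CPU-side C++ sketch of just the formula (row-major, untiled; addmm_ref is an illustrative name, not part of the codebase):

#include <iostream>
#include <vector>

// Sketch: out[m][n] = beta * bias[m][n] + alpha * sum_k self[m][k] * other[k][n]
// with row-major M x K, K x N and M x N buffers. No tiling; just the math.
void addmm_ref(const std::vector<float>& self, const std::vector<float>& other,
               const std::vector<float>& bias, std::vector<float>& out,
               int M, int K, int N, float alpha, float beta) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += self[m * K + k] * other[k * N + n];
      }
      out[m * N + n] = beta * bias[m * N + n] + alpha * acc;
    }
  }
}

int main() {
  // 1x2 @ 2x1 with alpha=2, beta=1: 1*5 + 2*(1*3 + 2*4) = 27
  std::vector<float> a = {1, 2}, b = {3, 4}, bias = {5}, out(1);
  addmm_ref(a, b, bias, out, 1, 2, 1, 2.0f, 1.0f);
  std::cout << out[0] << "\n"; // 27
  return 0;
}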
std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*> do_mm(MPSGraph* graph,
const Tensor& self,
const Tensor& other) {
@ -699,6 +644,7 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
TORCH_CHECK(output.is_mps());
TORCH_CHECK(self.dim() == 2 && other.dim() == 2, "tensors must be 2-D");
TORCH_CHECK(supportedFloatingOrComplexType(self), "MPS device does not support addmm for non-float input");
TensorArg args[]{{output, "out", 0}, {bias, "self", 1}, {self, "mat1", 2}, {other, "mat2", 3}};
checkAllSameGPU(__func__, args);
@ -725,10 +671,6 @@ static Tensor& addmm_out_mps_impl(const Tensor& bias,
return output;
}
if (use_metal_mm(self, other, output)) {
return do_metal_addmm(self, other, output, alpha, beta, *bias_);
}
bool is_beta_non_zero = beta.toDouble() != 0.0;
struct CachedGraph : public mps::MPSCachedGraph {

View File

@ -297,13 +297,13 @@ static PoolSizes process_pool_sizes(const Tensor& input,
pooling_dims,
" ints");
TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == pooling_dims,
TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 3,
op_name,
": stride must either be omitted, a single int, or a tuple of ",
pooling_dims,
" ints");
TORCH_CHECK(padding.size() == 1 || padding.size() == pooling_dims,
TORCH_CHECK(padding.size() == 1 || padding.size() == 3,
op_name,
": padding must either be a single int, or a tuple of ",
pooling_dims,
@ -333,22 +333,6 @@ static PoolSizes process_pool_sizes(const Tensor& input,
": pad should be at most half of effective kernel size");
}
if (pooling_dims == 2) {
const auto memory_format = input.suggest_memory_format();
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
if (memory_format == at::MemoryFormat::ChannelsLast) {
// Expect tensor in NHWC format and allow 0-dim only for N.
TORCH_CHECK((dims == 4 && valid_dims && input.size(3) != 0),
"Expected 4D (batch mode) tensor expected for input with channels_last layout"
" with optional 0 dim batch size for input, but got: ",
input.sizes());
} else {
TORCH_CHECK((dims == 3 && input.size(0) != 0 && valid_dims) || (dims == 4 && valid_dims && input.size(3) != 0),
"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got:",
input.sizes());
}
}
for (const auto dim : c10::irange(static_cast<int>(leading_dims == 2), dims)) {
TORCH_CHECK(input.size(dim) > 0, op_name, ": Expected input's non-batch dimensions to have positive length");
}
@ -802,16 +786,6 @@ static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
} // namespace mps
// TODO: The MPS graph impl can sometimes give significantly better performance
// than the Metal impl for cases where the stride is 1 in all dimensions. There
// may be a code path in the graph kernel that specifically optimizes for that
// case. We should look into implementing a specialized case in Metal so we can
// avoid using the graph impl.
static bool use_graph_for_max_pool2d(IntArrayRef kernel_size, IntArrayRef stride_) {
IntArrayRef stride = stride_.empty() ? kernel_size : stride_;
return (stride[0] == 1) && (stride.size() == 1 || stride[1] == 1);
}
Tensor mps_max_pool2d(const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
@ -819,37 +793,24 @@ Tensor mps_max_pool2d(const Tensor& input,
IntArrayRef dilation,
bool ceil_mode) {
Tensor output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
if (use_graph) {
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
};
mps::pool2d_template(input,
output,
std::nullopt,
std::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
std::nullopt,
pooling_op_block,
"max_pool2d");
} else {
mps::max_pool_with_indices_out_mps_template(output,
std::nullopt,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
/*pooling_dims=*/2,
"max_pool2d");
}
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
return [mpsGraph maxPooling2DWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
};
mps::pool2d_template(input,
output,
std::nullopt,
std::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
std::nullopt,
pooling_op_block,
"max_pool2d");
return output;
}
@ -894,45 +855,32 @@ TORCH_IMPL_FUNC(max_pool2d_with_indices_out_mps)
bool ceil_mode,
const Tensor& output,
const Tensor& indices) {
bool use_graph = use_graph_for_max_pool2d(kernel_size, stride);
if (use_graph) {
auto indices_memory_format = indices.suggest_memory_format();
auto indices_memory_format = indices.suggest_memory_format();
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs =
[mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor descriptor:desc name:nil];
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
return poolOutputs[0];
};
mps::pool2d_template(input,
output,
indices,
std::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
std::nullopt,
pooling_op_block,
"max_pool2d_indices");
if (indices_memory_format == MemoryFormat::ChannelsLast) {
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
}
mps::PoolingOpBlock pooling_op_block = ^PoolingOpFn(cachedGraph, desc) {
MPSGraph* mpsGraph = cachedGraph.graph();
NSArray<MPSGraphTensor*>* poolOutputs = [mpsGraph maxPooling2DReturnIndicesWithSourceTensor:cachedGraph.inputTensor
descriptor:desc
name:nil];
cachedGraph.indicesTensor = mps::castMPSTensor(mpsGraph, poolOutputs[1], ScalarType::Long);
return poolOutputs[0];
};
mps::pool2d_template(input,
output,
indices,
std::nullopt,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
false,
std::nullopt,
pooling_op_block,
"max_pool2d_indices");
} else {
mps::max_pool_with_indices_out_mps_template(output,
indices,
input,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
/*pooling_dims=*/2,
"max_pool2d");
if (indices_memory_format == MemoryFormat::ChannelsLast) {
const_cast<Tensor&>(indices) = indices.to(MemoryFormat::ChannelsLast);
}
}

View File

@ -152,6 +152,8 @@ static void reduction_out_mps(const Tensor& input_t,
const Tensor& output_t,
MPSReductionType reduction_type,
const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
// NS: TODO: get rid of all those shenanigans and just call reduction_op with view tensor
bool canSqueezeLastDim = true;
IntArrayRef input_shape = input_t.sizes();
@ -234,10 +236,12 @@ static void reduction_out_mps(const Tensor& input_t,
MPSGraphTensor* castInputTensor = inputTensor;
MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt || dtype.value() == kLong)) {
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
(dtype.value() == kLong && macOS13_3_plus))) {
inputCastType = getMPSDataType(dtype.value());
} else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
inputScalarType != kComplexFloat && inputScalarType != kComplexHalf && inputScalarType != kLong) {
inputScalarType != kComplexFloat && inputScalarType != kComplexHalf &&
(inputScalarType != kLong || !macOS13_3_plus)) {
inputCastType = getMPSDataType(kFloat);
}
@ -611,6 +615,9 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
}
static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, nanmedian ? "nanmedian" : "median");
IntArrayRef input_shape = input_t.sizes();
int64_t num_in_elements = c10::multiply_integers(input_shape);
@ -627,7 +634,8 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
auto medianCachedGraph =
LookUpOrCreateCachedGraph<MedianCachedGraph>(medianKey, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:castInputTensor withShape:@[ @-1 ] name:nil];
@ -685,6 +693,9 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
}
static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction_type, const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");
using CachedGraph = MPSUnaryCachedGraph;
IntArrayRef input_shape = input_t.sizes();
@ -702,7 +713,8 @@ static Tensor min_max_mps_impl(const Tensor& input_t, MPSReductionType reduction
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castOutputTensor = nil;
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
NSArray<NSNumber*>* axes = getTensorAxes(input_t);
if (reduction_type == MPSReductionType::MAX) {
@ -737,6 +749,9 @@ static void min_max_out_mps(const Tensor& input_t,
const Tensor& indices_t,
MPSReductionType reduction_type,
const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");
if (output_t.numel() == 0) {
return;
}
@ -774,7 +789,8 @@ static void min_max_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* outputTensor = nil;
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
if (reduction_type == MPSReductionType::MAX) {
outputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axis:(NSInteger)dim_ name:nil];
@ -880,6 +896,9 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
const std::string& func_name) {
using CachedGraph = MPSUnaryCachedGraph;
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");
int64_t dim_ = -1;
if (dim.has_value()) {
@ -934,7 +953,7 @@ static void argmax_argmin_out_mps(const Tensor& input_t,
MPSGraphTensor* castInputTensor = inputTensor;
if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
inputScalarType != kLong) {
(inputScalarType != kLong || !macOS13_3_plus)) {
castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
}
if (reduction_type == MPSReductionType::MAX) {
@ -1263,6 +1282,9 @@ static void all_any_common_impl_mps(const Tensor& input_t,
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, op_name);
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, op_name.c_str());
@ -1281,7 +1303,7 @@ static void all_any_common_impl_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionOrWithTensor:axis: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
MPSGraphTensor* outputTensor = nil;
@ -1347,11 +1369,14 @@ TORCH_IMPL_FUNC(any_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
@autoreleasepool {
std::string key = std::string("any_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionOrWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.dim() > 4) {
@ -1395,11 +1420,14 @@ TORCH_IMPL_FUNC(all_all_out_mps)(const Tensor& input_t, const Tensor& output_t)
return;
}
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
@autoreleasepool {
std::string key = std::string("all_all_out_mps:") + getTensorsStringKey(input_t);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
auto castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
// reductionAndWithTensor:axes: will throw an internal assert if number of dimensions is more than 4
// See https://github.com/pytorch/pytorch/issues/95538
if (input_t.ndimension() > 4) {
@ -1484,6 +1512,9 @@ static void median_out_mps_common(const Tensor& input_t,
Tensor& indices,
const std::string& func_name,
bool nanmedian) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "max()");
@ -1554,7 +1585,8 @@ static void median_out_mps_common(const Tensor& input_t,
getTensorsStringKey(indices);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* effectiveLengthTensor = nil;
if (nanmedian) {

View File

@ -129,8 +129,16 @@ void computeRepeatIndices(const index_t* repeat_ptr,
});
}
Tensor repeat_interleave_mps(const Tensor& repeat, std::optional<int64_t> output_size) {
Tensor repeat_interleave_mps(const Tensor& repeat_, std::optional<int64_t> output_size) {
Tensor output;
Tensor repeat = repeat_;
if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of the output,
// which currently doesn't support int64_t as input, so cast the indices to int32_t internally.
TORCH_WARN_ONCE(
"MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
output = repeat_interleave_common<index_t, computeRepeatIndices<index_t>>(repeat, output_size);
});

View File

@ -23,6 +23,125 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/ScanKernel_metallib.h>
#endif
// Generic scan implementation that handles both simple scans and scans with indices
static void scan_mps_impl(const Tensor& self,
const std::vector<Tensor>& outputs,
int64_t dim,
const std::string& op_name) {
if (outputs[0].numel() == 0) {
return;
}
const int64_t ndim = self.dim();
const int64_t wrapped_dim = maybe_wrap_dim(dim, ndim);
// Calculate dimensions for scan operation
int64_t row_size = self.size(wrapped_dim);
auto sizes = self.sizes();
bool is_innermost = (wrapped_dim == ndim - 1);
// Check if all tensors are contiguous
bool is_contiguous = self.is_contiguous();
for (const auto& output : outputs) {
is_contiguous = is_contiguous && output.is_contiguous();
}
uint32_t num_rows, num_orows, num_irows, num_threads;
if (is_innermost) {
// Treat all outer dimensions as a single dimension
num_rows = self.numel() / row_size;
num_threads = num_rows;
} else {
// Treat all outer dimensions (i.e. dimensions before wrapped_dim) as one
num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + wrapped_dim);
// Treat all inner dimensions (i.e. dimensions after wrapped_dim) as one
num_irows = c10::multiply_integers(sizes.begin() + wrapped_dim + 1, sizes.end());
num_threads = num_orows * num_irows;
}
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
// Choose kernel based on contiguity and dimension
std::string kernel_name;
if (is_contiguous) {
kernel_name =
op_name + "_contiguous_" + (is_innermost ? "innermost_" : "outer_") + scalarToMetalTypeString(self);
} else {
kernel_name = op_name + "_strided_" + scalarToMetalTypeString(self);
}
id<MTLComputePipelineState> scanPSO = lib.getPipelineStateForFunc(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(scanPSO, op_name, [&]() {
std::vector<Tensor> all_tensors = {self};
all_tensors.insert(all_tensors.end(), outputs.begin(), outputs.end());
return all_tensors;
}());
[computeEncoder setComputePipelineState:scanPSO];
// Set input tensor
mtl_setBuffer(computeEncoder, self, 0);
// Set output tensors
for (size_t i = 0; i < outputs.size(); ++i) {
mtl_setBuffer(computeEncoder, outputs[i], i + 1);
}
if (is_contiguous) {
// Contiguous kernels
if (is_innermost) {
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder, num_rows, static_cast<uint32_t>(row_size));
}
} else {
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder, num_orows, num_irows, static_cast<uint32_t>(row_size));
}
}
} else {
// Strided kernels - pass full tensor information
if (outputs.size() == 1) {
// Simple scan
mtl_setArgs<2>(computeEncoder,
self.sizes(),
self.strides(),
outputs[0].strides(),
static_cast<uint32_t>(self.ndimension()),
static_cast<uint32_t>(wrapped_dim));
} else {
// Scan with indices
mtl_setArgs<3>(computeEncoder,
self.sizes(),
self.strides(),
outputs[0].strides(),
outputs[1].strides(),
static_cast<uint32_t>(self.ndimension()),
static_cast<uint32_t>(wrapped_dim));
}
}
mtl_dispatch1DJob(computeEncoder, scanPSO, num_threads);
getMPSProfiler().endProfileKernel(scanPSO);
}
});
}
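The num_orows / num_irows decomposition above amounts to: flatten every dimension before the scanned one into "outer rows" and every dimension after it into "inner rows", so consecutive scanned elements are num_irows apart in memory. A sequential C++ sketch of a cumulative sum using the same decomposition (assumes a contiguous row-major layout; cumsum_dim is an illustrative name):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: cumsum along `dim` of a contiguous row-major tensor using the
// outer-rows / inner-rows decomposition from the non-innermost case above.
void cumsum_dim(const std::vector<float>& in, std::vector<float>& out,
                const std::vector<int64_t>& sizes, int64_t dim) {
  int64_t row_size = sizes[dim];
  int64_t num_orows = 1, num_irows = 1;
  for (int64_t d = 0; d < dim; ++d) num_orows *= sizes[d];                             // dims before `dim`
  for (int64_t d = dim + 1; d < static_cast<int64_t>(sizes.size()); ++d) num_irows *= sizes[d]; // dims after `dim`

  for (int64_t orow = 0; orow < num_orows; ++orow) {
    for (int64_t irow = 0; irow < num_irows; ++irow) {
      int64_t base = orow * row_size * num_irows + irow;
      float acc = 0.0f;
      for (int64_t i = 0; i < row_size; ++i) {
        acc += in[base + i * num_irows]; // stride between scanned elements is num_irows
        out[base + i * num_irows] = acc;
      }
    }
  }
}

int main() {
  // 2x3 tensor, cumsum along dim 0.
  std::vector<float> in = {1, 2, 3,
                           4, 5, 6};
  std::vector<float> out(in.size());
  cumsum_dim(in, out, {2, 3}, 0);
  for (float v : out) std::cout << v << " "; // 1 2 3 5 7 9
  std::cout << "\n";
  return 0;
}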
// Utility function to get 2D grid dimensions for dispatch
static std::pair<uint32_t, uint32_t> get_2d_grid_dims(const IntArrayRef& shape, const int64_t dim) {
size_t grid_x = 1;
@ -256,11 +375,19 @@ static void scan_with_indices_mps_impl(const Tensor& self,
} // namespace mps
void cummax_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummax");
} else {
mps::scan_mps_impl(self, {values, indices}, dim, "cummax");
}
}
void cummin_helper_mps(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_with_indices_mps_impl(self, values, indices, dim, "cummin");
} else {
mps::scan_mps_impl(self, {values, indices}, dim, "cummin");
}
}
Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
@ -275,7 +402,11 @@ Tensor& _logcumsumexp_out_mps(const Tensor& self, int64_t dim, Tensor& result) {
return result;
}
mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
mps::scan_simple_mps_impl(self, result, wrap_dim, "logcumsumexp");
} else {
mps::scan_mps_impl(self, {result}, wrap_dim, "logcumsumexp");
}
return result;
}

View File

@ -26,6 +26,9 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
const Tensor& indices) {
using namespace mps;
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
if (self.numel() == 0) {
return;
}
@ -52,7 +55,8 @@ TORCH_IMPL_FUNC(sort_stable_out_mps)
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), input_shape);
MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
MPSGraphTensor* castInputTensor =
castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending

View File

@ -297,6 +297,9 @@ static void isin_Tensor_Tensor_out_mps(const Tensor& elements,
const auto common_type = at::result_type(elements, test_elements);
TORCH_CHECK(elements.is_mps() && test_elements.is_mps());
TORCH_CHECK(is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) || supportedFloatingType(common_type),
"isin_Tensor_Tensor_out only works on floating types on MPS for pre MacOS_14_0. Received dtype: ",
common_type);
@autoreleasepool {
std::string key = op_name + getTensorsStringKey({elements, test_elements}) + std::to_string(invert);

View File

@ -208,12 +208,28 @@ Tensor& logical_not_out_mps(const Tensor& self, Tensor& output) {
}
Tensor& angle_out_mps(const Tensor& self, Tensor& output) {
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
if (mps::supportsComplex()) {
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
auto realPart = [mpsGraph realPartOfTensor:inputTensor name:nil];
auto imagPart = [mpsGraph imaginaryPartOfTensor:inputTensor name:nil];
return [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:realPart name:nil];
});
return output;
} else {
TORCH_CHECK(!self.is_complex(), "MPS does not support angle with complex input on macOS13")
mps::unary_op(self, output, "angle_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
// On macOS 13 with non-complex input, realPartOfTensor and imaginaryPartOfTensor are
// not available, and NaN is not propagated correctly:
auto imagPart = [mpsGraph constantWithScalar:0.0 shape:inputTensor.shape dataType:inputTensor.dataType];
auto result = [mpsGraph atan2WithPrimaryTensor:imagPart secondaryTensor:inputTensor name:nil];
auto nanMask = [mpsGraph isNaNWithTensor:inputTensor name:nil];
return [mpsGraph selectWithPredicateTensor:nanMask
truePredicateTensor:inputTensor
falsePredicateTensor:result
name:nil];
});
return output;
}
}
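The non-complex branch reduces angle() to atan2(0, x), i.e. 0 for positive inputs and pi for negative ones, with an explicit NaN passthrough because the graph would not propagate NaN on its own. A tiny standalone C++ sketch of that behaviour (angle_real is an illustrative name):

#include <cmath>
#include <iostream>

// Sketch: angle() for a real input, mirroring the fallback above:
// atan2(0, x) gives 0 for x >= 0 and pi for x < 0; NaN inputs are passed
// through explicitly, like the select-on-isnan step in the graph.
float angle_real(float x) {
  if (std::isnan(x)) {
    return x;
  }
  return std::atan2(0.0f, x);
}

int main() {
  std::cout << angle_real(2.0f) << " "          // 0
            << angle_real(-3.0f) << " "         // ~3.14159
            << angle_real(std::nanf("")) << "\n"; // nan
  return 0;
}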
Tensor angle_mps(const Tensor& self) {
@ -346,6 +362,7 @@ static void cumulative_op_impl(const Tensor& self,
const Tensor& result,
MPSCumulativeOpType cumulativeOpType,
const std::string& op_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >= 0 && wrapped_dim < std::max(1LL, self.ndimension()),
@ -364,6 +381,11 @@ static void cumulative_op_impl(const Tensor& self,
bool castInputData = (isIntegralType(input.scalar_type(), true) && input.scalar_type() != ScalarType::Int &&
input.scalar_type() != ScalarType::Long);
TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
"MPS does not support ",
op_name,
" op with int64 input. Support has been added in macOS 13.3");
mps::unary_op(
input, result, op_name + std::to_string(dim), ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
if (castInputData) {
@ -418,9 +440,17 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
TORCH_CHECK(self.is_complex());
mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph conjugateWithTensor:inputTensor name:nil];
});
if (!mps::supportsComplex()) {
if (!result.is_same_size(self)) {
result.resize_(self.sizes());
}
at::real(result).copy_(at::real(self));
at::imag(result).copy_(at::neg(at::imag(self)));
} else {
mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph conjugateWithTensor:inputTensor name:nil];
});
}
return result;
}

View File

@ -7423,7 +7423,6 @@
dispatch:
SparseCPU: _coalesce_sparse_cpu
SparseCUDA: _coalesce_sparse_cuda
SparseMPS: _coalesce_sparse_mps
autogen: _coalesce.out
- func: is_coalesced(Tensor self) -> bool
@ -15013,7 +15012,6 @@
- func: _scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
NestedTensorCUDA: _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda
tags: nondeterministic_seeded
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
@ -15046,11 +15044,6 @@
CUDA: _cudnn_attention_forward
tags: nondeterministic_seeded
- func: _cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
dispatch:
CUDA: _cudnn_attention_backward
tags: nondeterministic_seeded
- func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor
variants: function
dispatch:

View File

@ -349,63 +349,6 @@ _scaled_dot_product_cudnn_attention_nestedtensor_cuda(
return std::make_tuple(std::move(attention), std::move(log_sumexp), cumulative_sequence_length_q, cumulative_sequence_length_kv, max_seqlen_batch_q, max_seqlen_batch_kv, std::move(cudnn_seed), std::move(cudnn_offset), Tensor());
}
std::tuple<Tensor, Tensor, Tensor> _scaled_dot_product_cudnn_attention_nestedtensor_backward_cuda(
const Tensor& grad_out,
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& out,
const Tensor& logsumexp,
const Tensor& philox_seed,
const Tensor& philox_offset,
const Tensor& attn_bias,
const Tensor& cum_seq_q,
const Tensor& cum_seq_k,
const int64_t max_q,
const int64_t max_k,
double dropout_p,
bool is_causal,
std::optional<double> scale) {
if (!grad_out.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
auto [
grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped] =
preprocessing::sdpa_nested_preprocessing_backward(
grad_out,
query,
key,
value,
out,
cum_seq_q,
cum_seq_k,
max_q,
max_k);
auto [dq, dk, dv] = at::_cudnn_attention_backward(grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped,
logsumexp,
philox_seed,
philox_offset,
attn_bias,
cum_seq_q,
cum_seq_k,
max_q,
max_k,
dropout_p,
is_causal,
scale);
return std::make_tuple(std::move(dq), std::move(dk), std::move(dv));
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
const at::Tensor& grad_out_,
const at::Tensor& query,

View File

@ -333,14 +333,14 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
weight.scalar_type() == at::ScalarType::Float ||
weight.scalar_type() == at::ScalarType::Half,
"'embedding_bag_byte_prepack' only support float32 or float16.");
const auto weight_sizes = weight.sym_sizes();
const auto cols_dim = weight.ndimension() - 1;
const auto embedding_cols = weight_sizes[cols_dim];
const auto weight_sizes = weight.sizes();
const auto cols_dim = weight_sizes.size() - 1;
const int32_t embedding_cols = static_cast<int32_t>(weight_sizes[cols_dim]);
// Add 8 bytes per column to store FP32 scale and zero_point per row.
const auto output_columns = embedding_cols + 2 * sizeof(float);
const int32_t output_columns = static_cast<int32_t>(embedding_cols + 2 * sizeof(float));
// Adjust output dimensions to account for FP32 scale and zero_points.
auto output_shape = weight_sizes.vec();
std::vector<int64_t> output_shape = weight_sizes.vec();
output_shape.at(cols_dim) = output_columns;
at::SymDimVector output_shape_vec(output_shape);
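The shape arithmetic here is small but easy to misread: each row gains 2 * sizeof(float) = 8 extra byte-columns to hold the per-row FP32 scale and zero point. A minimal C++ sketch of the output-shape computation (byte_prepack_shape is an illustrative name):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: output shape for byte prepacking — the last dim grows by
// 2 * sizeof(float) bytes to hold a per-row FP32 scale and zero point.
std::vector<int64_t> byte_prepack_shape(std::vector<int64_t> weight_sizes) {
  const size_t cols_dim = weight_sizes.size() - 1;
  weight_sizes[cols_dim] += 2 * static_cast<int64_t>(sizeof(float)); // +8 bytes per row
  return weight_sizes;
}

int main() {
  auto shape = byte_prepack_shape({16, 64}); // 16 rows of 64 columns
  std::cout << shape[0] << " x " << shape[1] << "\n"; // 16 x 72
  return 0;
}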

View File

@ -1,220 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#endif
namespace at::native {
using namespace mps;
using namespace at::sparse;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Sparse_metallib.h>
#endif
static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {
TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
"flatten_indices: indices.size(0) must equal size.size()");
int64_t sparse_dim = indices.size(0);
int64_t nnz = indices.size(1);
if (nnz == 0) {
return at::empty({0}, indices.options().dtype(kLong));
}
std::vector<int64_t> strides(sparse_dim);
strides[sparse_dim - 1] = 1;
for (int64_t i = sparse_dim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * size[i + 1];
}
Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return flat_indices;
}
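flatten_indices maps each sparse coordinate to a single linear index using row-major strides derived from the tensor sizes; the Metal kernel does this once per nnz element in parallel. A CPU-side C++ sketch of the same computation (flatten_indices_cpu is an illustrative name):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: flatten [sparse_dim x nnz] coordinates into nnz linear indices using
// row-major strides computed from `size`, as the flatten_indices kernel does.
std::vector<int64_t> flatten_indices_cpu(const std::vector<std::vector<int64_t>>& indices,
                                         const std::vector<int64_t>& size) {
  const int64_t sparse_dim = static_cast<int64_t>(size.size());
  const int64_t nnz = indices.empty() ? 0 : static_cast<int64_t>(indices[0].size());
  std::vector<int64_t> strides(sparse_dim, 1);
  for (int64_t i = sparse_dim - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * size[i + 1];
  }
  std::vector<int64_t> flat(nnz, 0);
  for (int64_t j = 0; j < nnz; ++j) {
    for (int64_t d = 0; d < sparse_dim; ++d) {
      flat[j] += indices[d][j] * strides[d];
    }
  }
  return flat;
}

int main() {
  // Two nnz entries in a 3x4 tensor: (0,1) -> 1, (2,3) -> 11.
  auto flat = flatten_indices_cpu({{0, 2}, {1, 3}}, {3, 4});
  std::cout << flat[0] << " " << flat[1] << "\n"; // 1 11
  return 0;
}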
static Tensor compute_output_positions(const Tensor& is_unique) {
int64_t nnz = is_unique.size(0);
if (nnz == 0) {
return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
}
Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, is_unique, positions);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return positions;
}
static Tensor compute_output_positions_parallel(const Tensor& is_unique) {
int64_t nnz = is_unique.size(0);
if (nnz == 0) {
return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
}
// For small arrays, use the simple kernel;
// the speed of the naive kernel drops off after 4096 nnz elements
if (nnz <= 4096) {
return compute_output_positions(is_unique);
}
auto stream = getCurrentMPSStream();
Tensor positions = is_unique.to(kInt);
// Kogge-Stone parallel prefix sum
Tensor positions_cloned = positions.clone();
for (int64_t stride = 1; stride < nnz; stride *= 2) {
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, positions, positions_cloned, stride);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
std::swap(positions, positions_cloned);
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, positions, positions_cloned);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return positions_cloned;
}
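This helper is an exclusive prefix sum over the is_unique mask: a Kogge-Stone inclusive scan (doubling strides, double-buffered between dispatches), followed by a shift-right so each entry receives the count of unique elements before it. A sequential C++ simulation of those steps, one inner loop per kernel dispatch (exclusive_positions is an illustrative name):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: Kogge-Stone inclusive prefix sum over an is_unique mask, then a
// shift-right to make it exclusive — mirroring the kogge_stone_step and
// shift_right_kernel dispatches, simulated sequentially.
std::vector<int32_t> exclusive_positions(const std::vector<int32_t>& is_unique) {
  const int64_t nnz = static_cast<int64_t>(is_unique.size());
  std::vector<int32_t> positions = is_unique;
  std::vector<int32_t> scratch = positions;
  for (int64_t stride = 1; stride < nnz; stride *= 2) {
    for (int64_t i = 0; i < nnz; ++i) { // one "kernel dispatch" per stride
      scratch[i] = positions[i] + (i >= stride ? positions[i - stride] : 0);
    }
    std::swap(positions, scratch);
  }
  // Shift right: exclusive scan, position 0 starts at 0.
  for (int64_t i = nnz - 1; i > 0; --i) {
    scratch[i] = positions[i - 1];
  }
  if (nnz > 0) scratch[0] = 0;
  return scratch;
}

int main() {
  // Unique mask 1 0 1 1 0 -> output positions 0 1 1 2 3.
  for (int32_t p : exclusive_positions({1, 0, 1, 1, 0})) std::cout << p << " ";
  std::cout << "\n";
  return 0;
}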
static std::pair<Tensor, int32_t> mark_unique_and_count(const Tensor& flat_indices) {
int64_t nnz = flat_indices.size(0);
if (nnz == 0) {
return {at::empty({0}, flat_indices.options().dtype(kBool)), 0};
}
Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool));
Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, flat_indices, is_unique, count_result);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
int32_t num_unique = count_result.item<int32_t>();
return {is_unique, num_unique};
}
SparseTensor _coalesce_sparse_mps(const SparseTensor& self) {
int64_t nnz = self._nnz();
TORCH_INTERNAL_ASSERT(!self.is_coalesced());
if (nnz < 2) {
SparseTensor dst = self.clone();
dst._coalesced_(true);
return dst;
}
Tensor indices = self._indices();
Tensor values = self._values();
Tensor flat_indices = flatten_indices(indices, self.sizes());
Tensor sorted_order = flat_indices.argsort();
Tensor flat_indices_sorted = flat_indices.index({sorted_order});
values = values.index({sorted_order});
indices = indices.index_select(1, sorted_order);
auto unique_info = mark_unique_and_count(flat_indices_sorted);
Tensor is_unique = unique_info.first;
int32_t newNnz = unique_info.second;
Tensor output_positions = compute_output_positions_parallel(is_unique);
Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options());
auto outValuesSize = values.sizes().vec();
outValuesSize[0] = newNnz;
Tensor out_values = at::zeros(outValuesSize, values.options());
Tensor is_unique_local = is_unique;
int64_t sparse_dim = indices.size(0);
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values));
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
const uint32_t numThreads = static_cast<uint32_t>(nnz);
const uint32_t valueSize = static_cast<uint32_t>(values.numel() / nnz);
mtl_setArgs(encoder,
flat_indices_sorted,
indices,
values,
is_unique_local,
output_positions,
out_indices,
out_values,
numThreads,
valueSize,
sparse_dim,
newNnz);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true);
return result;
}
} // namespace at::native
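Taken together, the coalesce pipeline in this file is: flatten and sort the indices, mark the first occurrence of each distinct index, exclusive-scan that mask into output positions, and accumulate the values of duplicates into the same slot. A condensed CPU-only C++ sketch of the same pipeline for already-flattened indices and scalar values (a simplification; the real kernel copies whole value rows and rebuilds the N-D indices):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Sketch: coalesce duplicate sparse entries for scalar values, mirroring the
// flatten -> sort -> mark-unique -> scan-positions -> accumulate pipeline.
void coalesce(std::vector<int64_t>& flat_indices, std::vector<float>& values) {
  const int64_t nnz = static_cast<int64_t>(flat_indices.size());
  if (nnz < 2) return;

  // Sort entries by flattened index (argsort + gather).
  std::vector<int64_t> order(nnz);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int64_t a, int64_t b) { return flat_indices[a] < flat_indices[b]; });

  std::vector<int64_t> out_indices;
  std::vector<float> out_values;
  for (int64_t j = 0; j < nnz; ++j) {
    int64_t idx = flat_indices[order[j]];
    float val = values[order[j]];
    if (out_indices.empty() || out_indices.back() != idx) {
      out_indices.push_back(idx); // first occurrence: new output slot
      out_values.push_back(val);
    } else {
      out_values.back() += val;   // duplicate: accumulate into existing slot
    }
  }
  flat_indices = std::move(out_indices);
  values = std::move(out_values);
}

int main() {
  std::vector<int64_t> idx = {5, 1, 5, 1};
  std::vector<float> val = {1.f, 2.f, 3.f, 4.f};
  coalesce(idx, val);
  for (size_t i = 0; i < idx.size(); ++i) std::cout << idx[i] << ":" << val[i] << " ";
  std::cout << "\n"; // 1:6 5:4
  return 0;
}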

Some files were not shown because too many files have changed in this diff.