Compare commits

1 Commit

449a8bff92  Updated docs to add the error case for torch.multinomial  (2024-05-03 13:00:37 -07:00)

Summary: Updated docs to add the error condition for torch.multinomial

Test Plan: No change in code

Reviewers:

Subscribers: @drisspg

Tasks:

Tags:

929 changed files with 60638 additions and 34771 deletions
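
The documentation change itself is not among the file views reproduced below, so the exact error condition documented for torch.multinomial is not visible in this compare view. As an illustrative sketch only (the specific case shown is an assumption, not taken from the diff), one error torch.multinomial is known to raise is requesting more samples than there are categories with replacement disabled:

import torch

# Probabilities over three categories.
probs = torch.tensor([0.2, 0.3, 0.5])

# Fine: without replacement, up to probs.numel() samples can be drawn.
print(torch.multinomial(probs, num_samples=3, replacement=False))

# Assumed illustration of an error case: asking for more samples than
# categories without replacement raises a RuntimeError.
try:
    torch.multinomial(probs, num_samples=4, replacement=False)
except RuntimeError as err:
    print("RuntimeError:", err)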


@ -204,7 +204,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.0
ROCM_VERSION=5.7
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
@ -215,7 +215,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.1
ROCM_VERSION=6.0
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
@ -306,12 +306,6 @@ case "$image" in
DB=yes
VISION=yes
CONDA_CMAKE=yes
# snadampal: skipping sccache due to the following issue
# https://github.com/pytorch/pytorch/issues/121559
SKIP_SCCACHE_INSTALL=yes
# snadampal: skipping llvm src build install because the current version
# from pytorch/llvm:9.0.1 is x86 specific
SKIP_LLVM_SRC_BUILD_INSTALL=yes
;;
*)
# Catch-all for builds that are not hardcoded.
@ -366,7 +360,7 @@ if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then
fi
# Build image
docker build \
DOCKER_BUILDKIT=1 docker build \
--no-cache \
--progress=plain \
--build-arg "BUILD_ENVIRONMENT=${image}" \
@ -405,8 +399,6 @@ docker build \
--build-arg "EXECUTORCH=${EXECUTORCH}" \
--build-arg "BASEKIT_VERSION=${BASEKIT_VERSION}" \
--build-arg "ACL=${ACL:-}" \
--build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \
--build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \
-f $(dirname ${DOCKERFILE})/Dockerfile \
-t "$tmp_tag" \
"$@" \


@ -113,6 +113,7 @@ install_centos() {
glibc-devel \
glibc-headers \
glog-devel \
hiredis-devel \
libstdc++-devel \
libsndfile-devel \
make \


@ -4,6 +4,11 @@ set -ex
install_ubuntu() {
apt-get update
apt-get install -y --no-install-recommends \
libhiredis-dev \
libleveldb-dev \
liblmdb-dev \
libsnappy-dev
# Cleanup
apt-get autoclean && apt-get clean
@ -15,6 +20,12 @@ install_centos() {
# See http://fedoraproject.org/wiki/EPEL
yum --enablerepo=extras install -y epel-release
yum install -y \
hiredis-devel \
leveldb-devel \
lmdb-devel \
snappy-devel
# Cleanup
yum clean all
rm -rf /var/cache/yum


@ -61,10 +61,6 @@ install_ubuntu() {
rocprofiler-dev \
roctracer-dev
if [[ $(ver $ROCM_VERSION) -ge $(ver 6.1) ]]; then
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev
fi
# precompiled miopen kernels added in ROCm 3.5, renamed in ROCm 5.5
# search for all unversioned packages
# if search fails it will abort this script; use true to avoid case where search fails


@ -263,10 +263,10 @@ unittest-xml-reporting<=3.2.0,>=2.0.0
#Pinned versions:
#test that import:
#lintrunner is supported on aarch64-linux only from 0.12.4 version
lintrunner==0.12.5
#wheel not found on aarch64, and source build requires rust
lintrunner==0.10.7 ; platform_machine == "x86_64"
#Description: all about linters!
#Pinned versions: 0.12.5
#Pinned versions: 0.10.7
#test that import:
rockset==1.0.3
@ -279,9 +279,9 @@ ghstack==0.8.0
#Pinned versions: 0.8.0
#test that import:
jinja2==3.1.4
jinja2==3.1.3
#Description: jinja2 template engine
#Pinned versions: 3.1.4
#Pinned versions: 3.1.3
#test that import:
pytest-cpp==2.3.0


@ -169,11 +169,9 @@ RUN rm install_acl.sh
ENV INSTALLED_ACL ${ACL}
# Install ccache/sccache (do this last, so we get priority in PATH)
ARG SKIP_SCCACHE_INSTALL
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH
RUN if [ -z "${SKIP_SCCACHE_INSTALL}" ]; then bash ./install_cache.sh; fi
RUN rm install_cache.sh
RUN bash ./install_cache.sh && rm install_cache.sh
# Add jni.h for java host build
COPY ./common/install_jni.sh install_jni.sh
@ -190,9 +188,7 @@ ARG BUILD_ENVIRONMENT
ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT}
# Install LLVM dev version (Defined in the pytorch/builder github repository)
ARG SKIP_LLVM_SRC_BUILD_INSTALL
COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm
RUN if [ -n "${SKIP_LLVM_SRC_BUILD_INSTALL}" ]; then set -eu; rm -rf /opt/llvm; fi
# AWS specific CUDA build guidance
ENV TORCH_CUDA_ARCH_LIST Maxwell


@ -81,22 +81,7 @@ if ! which conda; then
export USE_MKLDNN=0
fi
else
# CMAKE_PREFIX_PATH precedences
# 1. $CONDA_PREFIX, if defined. This follows the pytorch official build instructions.
# 2. /opt/conda/envs/py_${ANACONDA_PYTHON_VERSION}, if ANACONDA_PYTHON_VERSION defined.
# This is for CI, which defines ANACONDA_PYTHON_VERSION but not CONDA_PREFIX.
# 3. $(conda info --base). The fallback value of pytorch official build
# instructions actually refers to this.
# Commonly this is /opt/conda/
if [[ -v CONDA_PREFIX ]]; then
export CMAKE_PREFIX_PATH=${CONDA_PREFIX}
elif [[ -v ANACONDA_PYTHON_VERSION ]]; then
export CMAKE_PREFIX_PATH="/opt/conda/envs/py_${ANACONDA_PYTHON_VERSION}"
else
# already checked by `! which conda`
CMAKE_PREFIX_PATH="$(conda info --base)"
export CMAKE_PREFIX_PATH
fi
export CMAKE_PREFIX_PATH=/opt/conda
# Workaround required for MKL library linkage
# https://github.com/pytorch/pytorch/issues/119557
@ -391,8 +376,4 @@ if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]];
python tools/stats/export_test_times.py
fi
# snadampal: skipping it till sccache support added for aarch64
# https://github.com/pytorch/pytorch/issues/121559
if [[ "$BUILD_ENVIRONMENT" != *aarch64* ]]; then
print_sccache_stats
fi
print_sccache_stats


@ -181,11 +181,6 @@ if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then
export PATH="$HOME/.local/bin:$PATH"
fi
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
# TODO: revisit this once the CI is stabilized on aarch64 linux
export VALGRIND=OFF
fi
install_tlparse
# DANGER WILL ROBINSON. The LD_PRELOAD here could cause you problems
@ -310,23 +305,22 @@ test_dynamo_shard() {
test_inductor_distributed() {
# Smuggle a few multi-gpu tests here so that we don't have to request another large node
echo "Testing multi_gpu tests in test_torchinductor"
python test/run_test.py -i inductor/test_torchinductor.py -k test_multi_gpu --verbose
python test/run_test.py -i inductor/test_aot_inductor.py -k test_non_default_cuda_device --verbose
python test/run_test.py -i inductor/test_aot_inductor.py -k test_replicate_on_devices --verbose
python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose
python test/run_test.py -i distributed/_tensor/test_dtensor_compile.py --verbose
python test/run_test.py -i distributed/tensor/parallel/test_fsdp_2d_parallel.py --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_comm.py --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_multi_group --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_with_activation_checkpointing --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_mlp --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_hsdp --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_transformer_checkpoint_resume --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_gradient_accumulation --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_frozen.py --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose
python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose
python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose
pytest test/inductor/test_torchinductor.py -k test_multi_gpu
pytest test/inductor/test_aot_inductor.py -k test_non_default_cuda_device
pytest test/inductor/test_aot_inductor.py -k test_replicate_on_devices
pytest test/distributed/test_c10d_functional_native.py
pytest test/distributed/_tensor/test_dtensor_compile.py
pytest test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
pytest test/distributed/_composable/fsdp/test_fully_shard_comm.py
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_multi_group
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_with_activation_checkpointing
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_mlp
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_hsdp
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_transformer_checkpoint_resume
pytest test/distributed/_composable/fsdp/test_fully_shard_frozen.py
pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype
pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype
pytest test/distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration
# this runs on both single-gpu and multi-gpu instances. It should be smart about skipping tests that aren't supported
# when the required number of GPUs isn't available
@ -522,11 +516,6 @@ test_single_dynamo_benchmark() {
fi
}
test_inductor_micro_benchmark() {
TEST_REPORTS_DIR=$(pwd)/test/test-micro-reports
python benchmarks/gpt_fast/benchmark.py
}
test_dynamo_benchmark() {
# Usage: test_dynamo_benchmark huggingface 0
TEST_REPORTS_DIR=$(pwd)/test/test-reports
@ -1163,33 +1152,11 @@ test_executorch() {
assert_git_not_dirty
}
test_linux_aarch64(){
python test/run_test.py --include test_modules test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \
test_transformers test_multiprocessing test_numpy_interop --verbose
# Dynamo tests
python test/run_test.py --include dynamo/test_compile dynamo/test_backends dynamo/test_comptime dynamo/test_config \
dynamo/test_functions dynamo/test_fx_passes_pre_grad dynamo/test_interop dynamo/test_model_output dynamo/test_modules \
dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles --verbose
# Inductor tests
python test/run_test.py --include inductor/test_torchinductor inductor/test_benchmark_fusion inductor/test_codecache \
inductor/test_config inductor/test_control_flow inductor/test_coordinate_descent_tuner inductor/test_fx_fusion \
inductor/test_group_batch_fusion inductor/test_inductor_freezing inductor/test_inductor_utils \
inductor/test_inplacing_pass inductor/test_kernel_benchmark inductor/test_layout_optim \
inductor/test_max_autotune inductor/test_memory_planning inductor/test_metrics inductor/test_multi_kernel inductor/test_pad_mm \
inductor/test_pattern_matcher inductor/test_perf inductor/test_profiler inductor/test_select_algorithm inductor/test_smoke \
inductor/test_split_cat_fx_passes inductor/test_standalone_compile inductor/test_torchinductor \
inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes --verbose
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
fi
if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
test_linux_aarch64
elif [[ "${TEST_CONFIG}" == *backward* ]]; then
if [[ "${TEST_CONFIG}" == *backward* ]]; then
test_forward_backward_compatibility
# Do NOT add tests after bc check tests, see its comment.
elif [[ "${TEST_CONFIG}" == *xla* ]]; then
@ -1214,8 +1181,6 @@ elif [[ "$TEST_CONFIG" == deploy ]]; then
test_torch_deploy
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
test_inductor_micro_benchmark
elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then
install_torchvision
id=$((SHARD_NUMBER-1))


@ -17,22 +17,22 @@ set PATH=C:\Program Files\CMake\bin;C:\Program Files\7-Zip;C:\ProgramData\chocol
set INSTALLER_DIR=%SCRIPT_HELPERS_DIR%\installation-helpers
call %INSTALLER_DIR%\install_magma.bat
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
call %INSTALLER_DIR%\install_sccache.bat
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
:: Miniconda has been installed as part of the Windows AMI with all the dependencies.
:: We just need to activate it here
call %INSTALLER_DIR%\activate_miniconda3.bat
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
call pip install mkl-include==2021.4.0 mkl-devel==2021.4.0
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
:: Override VS env here
pushd .
@ -41,8 +41,8 @@ if "%VC_VERSION%" == "" (
) else (
call "C:\Program Files (x86)\Microsoft Visual Studio\%VC_YEAR%\%VC_PRODUCT%\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%VC_VERSION%
)
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
@echo on
popd
@ -52,12 +52,12 @@ set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION%
if x%CUDA_VERSION:.=%==x%CUDA_VERSION% (
echo CUDA version %CUDA_VERSION% format isn't correct, which doesn't contain '.'
goto fail
exit /b 1
)
rem version transformer, for example 10.1 to 10_1.
if x%CUDA_VERSION:.=%==x%CUDA_VERSION% (
echo CUDA version %CUDA_VERSION% format isn't correct, which doesn't contain '.'
goto fail
exit /b 1
)
set VERSION_SUFFIX=%CUDA_VERSION:.=_%
set CUDA_PATH_V%VERSION_SUFFIX%=%CUDA_PATH%
@ -101,8 +101,8 @@ if "%USE_CUDA%"=="1" (
:: CMake requires a single command as CUDA_NVCC_EXECUTABLE, so we push the wrappers
:: randomtemp.exe and sccache.exe into a batch file which CMake invokes.
curl -kL https://github.com/peterjc123/randomtemp-rust/releases/download/v0.4/randomtemp.exe --output %TMP_DIR_WIN%\bin\randomtemp.exe
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
echo @"%TMP_DIR_WIN%\bin\randomtemp.exe" "%TMP_DIR_WIN%\bin\sccache.exe" "%CUDA_PATH%\bin\nvcc.exe" %%* > "%TMP_DIR%/bin/nvcc.bat"
cat %TMP_DIR%/bin/nvcc.bat
set CUDA_NVCC_EXECUTABLE=%TMP_DIR%/bin/nvcc.bat
@ -114,8 +114,8 @@ if "%USE_CUDA%"=="1" (
set
python setup.py bdist_wheel
if errorlevel 1 goto fail
if not errorlevel 0 goto fail
if errorlevel 1 exit /b
if not errorlevel 0 exit /b
sccache --show-stats
python -c "import os, glob; os.system('python -mpip install --no-index --no-deps ' + glob.glob('dist/*.whl')[0])"
(
@ -135,8 +135,3 @@ python -c "import os, glob; os.system('python -mpip install --no-index --no-deps
sccache --show-stats --stats-format json | jq .stats > sccache-stats-%BUILD_ENVIRONMENT%-%OUR_GITHUB_JOB_ID%.json
sccache --stop-server
exit /b 0
:fail
exit /b 1


@ -54,7 +54,6 @@ per-file-ignores =
torch/ao/quantization/fx/_decomposed.py: TOR901
torch/distributed/_functional_collectives.py: TOR901
torch/distributed/_spmd/data_parallel.py: TOR901
torch/distributed/_tensor/_collective_utils.py: TOR901
optional-ascii-coding = True
exclude =
./.git,


@ -1 +1 @@
d23a6e1664d20707c11781299611436e1f0c104f
2c4665ffbb64f03f5d18016d3398af4ac4da5f03


@ -1 +1 @@
e3fc03314dab5f44e3ed9ccbba6c15fbca3285cd
58a412cb271a3f98ae2e01fd1d24bdbb66645d4e

.github/labeler.yml

@ -58,17 +58,6 @@
- third_party/mkl-dnn.BUILD
- torch/csrc/jit/codegen/onednn/**
- test/test_jit_llga_fuser.py
- test/test_mkldnn.py
"ciflow/linux-aarch64":
- third_party/ideep
- caffe2/ideep/**
- caffe2/python/ideep/**
- cmake/Modules/FindMKLDNN.cmake
- third_party/mkl-dnn.BUILD
- torch/csrc/jit/codegen/onednn/**
- test/test_jit_llga_fuser.py
- test/test_mkldnn.py
"module: amp (automated mixed precision)":
- torch/amp/**


@ -29,12 +29,10 @@
approved_by:
- BowenBao
- justinchuby
- liqunfu
- shubhambhokare1
- thiagocrepaldi
- titaiwangms
- wschin
- xadupre
mandatory_checks_name:
- EasyCLA
- Lint


@ -8,8 +8,6 @@ ciflow_push_tags:
- ciflow/binaries_wheel
- ciflow/inductor
- ciflow/inductor-perf-compare
- ciflow/inductor-micro-benchmark
- ciflow/linux-aarch64
- ciflow/mps
- ciflow/nightly
- ciflow/periodic


@ -5,7 +5,7 @@
# functorch/docs/requirements.txt
# .ci/docker/requirements-ci.txt
boto3==1.19.12
jinja2==3.1.4
jinja2==3.1.3
lintrunner==0.10.7
ninja==1.10.0.post1
nvidia-ml-py==11.525.84


@ -1,11 +1,7 @@
#!/bin/bash
set -x
if [ -z "$1" ]; then
echo "Need wheel location argument" && exit 1
fi
WHEELHOUSE_DIR=$1
WHEELHOUSE_DIR=/artifacts
PATCHELF_BIN=patchelf
ROCM_LIB=backends/amd/lib
ROCM_LD=backends/amd/llvm/bin


@ -157,10 +157,10 @@ def build_triton(
if build_rocm:
check_call(
[f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh", Path.cwd()],
[f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh"],
cwd=triton_basedir,
shell=True,
)
return Path.cwd() / whl_path.name


@ -13,16 +13,16 @@ architectures:
import os
from typing import Dict, List, Optional, Tuple
CUDA_ARCHES = ["11.8", "12.1", "12.4"]
CUDA_ARCHES = ["11.8", "12.1"]
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1", "12.4": "12.4.0"}
CUDA_ARCHES_FULL_VERSION = {"11.8": "11.8.0", "12.1": "12.1.1"}
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8", "12.4": "8"}
CUDA_ARCHES_CUDNN_VERSION = {"11.8": "8", "12.1": "8"}
ROCM_ARCHES = ["6.0", "6.1"]
ROCM_ARCHES = ["5.7", "6.0"]
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
@ -58,20 +58,6 @@ PYTORCH_EXTRA_INSTALL_REQUIREMENTS = {
"nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
"12.4": (
"nvidia-cuda-nvrtc-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-runtime-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu12==8.9.7.29; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cublas-cu12==12.4.2.65; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cufft-cu12==11.2.0.44; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-curand-cu12==10.3.5.119; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusolver-cu12==11.6.0.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cusparse-cu12==12.3.0.142; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu12==2.20.5; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvjitlink-cu12==12.4.99; platform_system == 'Linux' and platform_machine == 'x86_64'"
),
}
@ -338,7 +324,7 @@ def generate_wheels_matrix(
)
# 12.1 linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install
if arch_version in ["12.4", "12.1", "11.8"] and os == "linux":
if arch_version in ["12.1", "11.8"] and os == "linux":
ret.append(
{
"python_version": python_version,
@ -381,6 +367,5 @@ def generate_wheels_matrix(
return ret
validate_nccl_dep_consistency("12.4")
validate_nccl_dep_consistency("12.1")
validate_nccl_dep_consistency("11.8")


@ -21,8 +21,6 @@ DOCKER_IMAGE_TYPES = ["runtime", "devel"]
def generate_docker_matrix() -> Dict[str, List[Dict[str, str]]]:
ret: List[Dict[str, str]] = []
# CUDA amd64 Docker images are available as both runtime and devel while
# CPU arm64 image is only available as runtime.
for cuda, version in generate_binary_build_matrix.CUDA_ARCHES_FULL_VERSION.items():
for image in DOCKER_IMAGE_TYPES:
ret.append(
@ -33,19 +31,9 @@ def generate_docker_matrix() -> Dict[str, List[Dict[str, str]]]:
cuda
],
"image_type": image,
"platform": "linux/amd64",
"platform": "linux/arm64,linux/amd64",
}
)
ret.append(
{
"cuda": "cpu",
"cuda_full_version": "",
"cudnn_version": "",
"image_type": "runtime",
"platform": "linux/arm64",
}
)
return {"include": ret}


@ -46,7 +46,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
!{{ common.concurrency(build_environment) }}
jobs:


@ -48,7 +48,7 @@ env:
BUILD_ENVIRONMENT: !{{ build_environment }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
{%- if cross_compile_arm64 %}
CROSS_COMPILE_ARM64: 1
{% endif %}


@ -37,7 +37,7 @@ jobs:
device: ["cuda", "rocm"]
include:
- device: "rocm"
rocm_version: "6.1"
rocm_version: "6.0"
- device: "cuda"
rocm_version: ""
timeout-minutes: 40


@ -7,7 +7,6 @@ on:
- Dockerfile
- docker.Makefile
- .github/workflows/docker-release.yml
- .github/scripts/generate_docker_release_matrix.py
push:
branches:
- nightly
@ -130,27 +129,17 @@ jobs:
if: ${{ github.event.ref == 'refs/heads/nightly' && matrix.image_type == 'runtime' }}
run: |
PYTORCH_DOCKER_TAG="${PYTORCH_VERSION}-cuda${CUDA_VERSION_SHORT}-cudnn${CUDNN_VERSION}-runtime"
CUDA_SUFFIX="-cu${CUDA_VERSION}"
if [[ ${CUDA_VERSION_SHORT} == "cpu" ]]; then
PYTORCH_DOCKER_TAG="${PYTORCH_VERSION}-runtime"
CUDA_SUFFIX=""
fi
PYTORCH_NIGHTLY_COMMIT=$(docker run ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \
python -c 'import torch; print(torch.version.git_version[:7],end="")')
docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \
ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}${CUDA_SUFFIX}"
docker push ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}${CUDA_SUFFIX}"
# Please note, here we need to pin a specific version of CUDA as with the latest label
if [[ ${CUDA_VERSION_SHORT} == "12.1" ]]; then
docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}${CUDA_SUFFIX}" \
ghcr.io/pytorch/pytorch-nightly:latest
docker push ghcr.io/pytorch/pytorch-nightly:latest
fi
ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}"
docker push ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}"
docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" \
ghcr.io/pytorch/pytorch-nightly:latest
docker push ghcr.io/pytorch/pytorch-nightly:latest
- name: Teardown Linux
uses: pytorch/test-infra/.github/actions/teardown-linux@main
if: always()


@ -31,7 +31,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-aarch64-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true


@ -31,7 +31,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-conda-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
@ -222,69 +222,6 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_8-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.8"
runs_on: linux.24xlarge
build_name: conda-py3_8-cuda12_4
build_environment: linux-binary-conda
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_8-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: conda-py3_8-cuda12_4-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.8"
build_name: conda-py3_8-cuda12_4
build_environment: linux-binary-conda
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_8-cuda12_4-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: conda-py3_8-cuda12_4-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.8"
build_name: conda-py3_8-cuda12_4
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_9-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -470,69 +407,6 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_9-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.9"
runs_on: linux.24xlarge
build_name: conda-py3_9-cuda12_4
build_environment: linux-binary-conda
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_9-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: conda-py3_9-cuda12_4-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.9"
build_name: conda-py3_9-cuda12_4
build_environment: linux-binary-conda
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_9-cuda12_4-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: conda-py3_9-cuda12_4-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.9"
build_name: conda-py3_9-cuda12_4
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_10-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -718,69 +592,6 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_10-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.10"
runs_on: linux.24xlarge
build_name: conda-py3_10-cuda12_4
build_environment: linux-binary-conda
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_10-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: conda-py3_10-cuda12_4-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.10"
build_name: conda-py3_10-cuda12_4
build_environment: linux-binary-conda
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_10-cuda12_4-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: conda-py3_10-cuda12_4-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.10"
build_name: conda-py3_10-cuda12_4
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_11-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -966,69 +777,6 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_11-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.11"
runs_on: linux.24xlarge
build_name: conda-py3_11-cuda12_4
build_environment: linux-binary-conda
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_11-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: conda-py3_11-cuda12_4-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.11"
build_name: conda-py3_11-cuda12_4
build_environment: linux-binary-conda
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_11-cuda12_4-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: conda-py3_11-cuda12_4-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.11"
build_name: conda-py3_11-cuda12_4
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_12-cpu-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
@ -1213,66 +961,3 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
conda-py3_12-cuda12_4-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.12"
runs_on: linux.24xlarge
build_name: conda-py3_12-cuda12_4
build_environment: linux-binary-conda
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_12-cuda12_4-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: conda-py3_12-cuda12_4-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.12"
build_name: conda-py3_12-cuda12_4
build_environment: linux-binary-conda
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-py3_12-cuda12_4-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: conda-py3_12-cuda12_4-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: conda
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/conda-builder:cuda12.4-main
DESIRED_PYTHON: "3.12"
build_name: conda-py3_12-cuda12_4
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml


@ -26,7 +26,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-libtorch-cxx11-abi-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true


@ -31,7 +31,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-libtorch-cxx11-abi-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
@ -229,7 +229,7 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_4-shared-with-deps-cxx11-abi-build:
libtorch-rocm5_7-shared-with-deps-cxx11-abi-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
@ -238,56 +238,97 @@ jobs:
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi
build_name: libtorch-rocm5_7-shared-with-deps-cxx11-abi
build_environment: linux-binary-libtorch-cxx11-abi
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cuda12_4-shared-with-deps-cxx11-abi-test: # Testing
libtorch-rocm5_7-shared-with-deps-cxx11-abi-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-cuda12_4-shared-with-deps-cxx11-abi-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
needs: libtorch-rocm5_7-shared-with-deps-cxx11-abi-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi
build_environment: linux-binary-libtorch-cxx11-abi
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cuda12_4-shared-with-deps-cxx11-abi-upload: # Uploading
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-rocm5_7-shared-with-deps-cxx11-abi
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/libtorch-cxx11-builder:rocm5.7-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm5_7-shared-with-deps-cxx11-abi-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_4-shared-with-deps-cxx11-abi-test
needs: libtorch-rocm5_7-shared-with-deps-cxx11-abi-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
build_name: libtorch-cuda12_4-shared-with-deps-cxx11-abi
build_name: libtorch-rocm5_7-shared-with-deps-cxx11-abi
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
@ -399,109 +440,3 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm6_1-shared-with-deps-cxx11-abi-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi
build_environment: linux-binary-libtorch-cxx11-abi
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm6_1-shared-with-deps-cxx11-abi-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-rocm6_1-shared-with-deps-cxx11-abi-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-rocm6_1-shared-with-deps-cxx11-abi
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/libtorch-cxx11-builder:rocm6.1-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm6_1-shared-with-deps-cxx11-abi-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm6_1-shared-with-deps-cxx11-abi-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: cxx11-abi
build_name: libtorch-rocm6_1-shared-with-deps-cxx11-abi
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml


@ -26,7 +26,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-libtorch-pre-cxx11-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true


@ -31,7 +31,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-libtorch-pre-cxx11-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
@ -229,7 +229,7 @@ jobs:
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_4-shared-with-deps-pre-cxx11-build:
libtorch-rocm5_7-shared-with-deps-pre-cxx11-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
@ -238,56 +238,97 @@ jobs:
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11
build_name: libtorch-rocm5_7-shared-with-deps-pre-cxx11
build_environment: linux-binary-libtorch-pre-cxx11
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cuda12_4-shared-with-deps-pre-cxx11-test: # Testing
libtorch-rocm5_7-shared-with-deps-pre-cxx11-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-cuda12_4-shared-with-deps-pre-cxx11-build
uses: ./.github/workflows/_binary-test-linux.yml
with:
needs: libtorch-rocm5_7-shared-with-deps-pre-cxx11-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11
build_environment: linux-binary-libtorch-pre-cxx11
runs_on: linux.4xlarge.nvidia.gpu
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-cuda12_4-shared-with-deps-pre-cxx11-upload: # Uploading
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-rocm5_7-shared-with-deps-pre-cxx11
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/manylinux-builder:rocm5.7-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm5_7-shared-with-deps-pre-cxx11-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_4-shared-with-deps-pre-cxx11-test
needs: libtorch-rocm5_7-shared-with-deps-pre-cxx11-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main
DESIRED_CUDA: rocm5.7
GPU_ARCH_VERSION: 5.7
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm5.7-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
build_name: libtorch-cuda12_4-shared-with-deps-pre-cxx11
build_name: libtorch-rocm5_7-shared-with-deps-pre-cxx11
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
@ -399,109 +440,3 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-rocm6_1-shared-with-deps-pre-cxx11-build:
if: ${{ github.repository_owner == 'pytorch' }}
uses: ./.github/workflows/_binary-build-linux.yml
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11
build_environment: linux-binary-libtorch-pre-cxx11
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
libtorch-rocm6_1-shared-with-deps-pre-cxx11-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-rocm6_1-shared-with-deps-pre-cxx11-build
runs-on: linux.rocm.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
SKIP_ALL_TESTS: 1
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
steps:
- name: Setup ROCm
uses: ./.github/actions/setup-rocm
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-rocm6_1-shared-with-deps-pre-cxx11
path: "${{ runner.temp }}/artifacts/"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: ROCm set GPU_FLAG
run: |
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon" >> "${GITHUB_ENV}"
- name: Pull Docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
with:
docker-image: pytorch/manylinux-builder:rocm6.1-main
- name: Test Pytorch binary
uses: ./pytorch/.github/actions/test-pytorch-binary
- name: Teardown ROCm
uses: ./.github/actions/teardown-rocm
libtorch-rocm6_1-shared-with-deps-pre-cxx11-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-rocm6_1-shared-with-deps-pre-cxx11-test
with:
PYTORCH_ROOT: /pytorch
BUILDER_ROOT: /builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: rocm6.1
GPU_ARCH_VERSION: 6.1
GPU_ARCH_TYPE: rocm
DOCKER_IMAGE: pytorch/manylinux-builder:rocm6.1-main
LIBTORCH_VARIANT: shared-with-deps
DESIRED_DEVTOOLSET: pre-cxx11
build_name: libtorch-rocm6_1-shared-with-deps-pre-cxx11
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml


@ -26,7 +26,7 @@ env:
PYTORCH_FINAL_PACKAGE_DIR: /artifacts
PYTORCH_ROOT: /pytorch
SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: linux-binary-manywheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true

File diff suppressed because it is too large.


@ -26,7 +26,7 @@ env:
BUILD_ENVIRONMENT: macos-arm64-binary-conda
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: macos-arm64-binary-conda-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true


@ -26,7 +26,7 @@ env:
BUILD_ENVIRONMENT: macos-arm64-binary-libtorch-cxx11-abi
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: macos-arm64-binary-libtorch-cxx11-abi-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true


@ -26,7 +26,7 @@ env:
BUILD_ENVIRONMENT: macos-arm64-binary-wheel
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
PR_NUMBER: ${{ github.event.pull_request.number }}
SKIP_ALL_TESTS: 0
SKIP_ALL_TESTS: 1
concurrency:
group: macos-arm64-binary-wheel-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true

File diff suppressed because it is too large.


@ -800,260 +800,3 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_4-shared-with-deps-debug-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: windows.4xlarge.nonephemeral
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v3
if: always()
with:
name: libtorch-cuda12_4-shared-with-deps-debug
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_4-shared-with-deps-debug-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-cuda12_4-shared-with-deps-debug-build
runs-on: windows.8xlarge.nvidia.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-cuda12_4-shared-with-deps-debug
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_4-shared-with-deps-debug-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_4-shared-with-deps-debug-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: debug
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
build_name: libtorch-cuda12_4-shared-with-deps-debug
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
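The `get_ec2_metadata` helper that the Windows jobs above call in their "Display EC2 information" step simply reads categories from the EC2 instance-metadata endpoint. As a rough illustration only (not part of the workflow; it assumes it runs on an EC2 instance where the IMDSv1 endpoint at 169.254.169.254 is reachable), the same lookup in Python:

import urllib.request

def get_ec2_metadata(category: str) -> str:
    # Same endpoint the bash helper queries with curl.
    url = f"http://169.254.169.254/latest/meta-data/{category}"
    with urllib.request.urlopen(url, timeout=2) as resp:
        return resp.read().decode()

# e.g. get_ec2_metadata("instance-type")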

View File

@ -800,260 +800,3 @@ jobs:
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml
libtorch-cuda12_4-shared-with-deps-release-build:
if: ${{ github.repository_owner == 'pytorch' }}
runs-on: windows.4xlarge.nonephemeral
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Build PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_build.sh"
- uses: actions/upload-artifact@v3
if: always()
with:
name: libtorch-cuda12_4-shared-with-deps-release
retention-days: 14
if-no-files-found: error
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_4-shared-with-deps-release-test: # Testing
if: ${{ github.repository_owner == 'pytorch' }}
needs: libtorch-cuda12_4-shared-with-deps-release-build
runs-on: windows.8xlarge.nvidia.gpu
timeout-minutes: 240
env:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
SKIP_ALL_TESTS: 1
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
steps:
- name: Display EC2 information
shell: bash
run: |
set -euo pipefail
function get_ec2_metadata() {
# Pulled from instance metadata endpoint for EC2
# see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html
category=$1
curl -fsSL "http://169.254.169.254/latest/meta-data/${category}"
}
echo "ami-id: $(get_ec2_metadata ami-id)"
echo "instance-id: $(get_ec2_metadata instance-id)"
echo "instance-type: $(get_ec2_metadata instance-type)"
echo "system info $(uname -a)"
- name: "[FB EMPLOYEES] Enable SSH (Click me for login details)"
uses: pytorch/test-infra/.github/actions/setup-ssh@main
continue-on-error: true
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
# Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560
- name: Enable long paths on Windows
shell: powershell
run: |
Set-ItemProperty -Path "HKLM:\\SYSTEM\CurrentControlSet\Control\FileSystem" -Name "LongPathsEnabled" -Value 1
# Since it's just a defensive command, the workflow should continue even if the command fails. This step can be
# removed once Windows Defender is removed from the AMI
- name: Disables Windows Defender scheduled and real-time scanning for files in directories used by PyTorch
continue-on-error: true
shell: powershell
run: |
Add-MpPreference -ExclusionPath $(Get-Location).tostring(),$Env:TEMP -ErrorAction Ignore
# Let's both exclude the path and disable Windows Defender completely just to be sure
# that it doesn't interfere
Set-MpPreference -DisableRealtimeMonitoring $True -ErrorAction Ignore
# NOTE: These environment variables are put here so that they can be applied on every job equally
# They are also here because setting them at a workflow level doesn't give us access to the
# runner.temp variable, which we need.
- name: Populate binary env
shell: bash
run: |
echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
echo "WIN_PACKAGE_WORK_DIR=${RUNNER_TEMP}"
- uses: actions/download-artifact@v3
name: Download Build Artifacts
with:
name: libtorch-cuda12_4-shared-with-deps-release
path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
- name: Checkout PyTorch
uses: malfet/checkout@silent-checkout
with:
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
submodules: recursive
path: pytorch
quiet-checkout: true
- name: Clean PyTorch checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: pytorch
- name: Checkout pytorch/builder
uses: malfet/checkout@silent-checkout
with:
ref: main
submodules: recursive
repository: pytorch/builder
path: builder
quiet-checkout: true
- name: Clean pytorch/builder checkout
run: |
# Remove any artifacts from the previous checkouts
git clean -fxd
working-directory: builder
- name: Populate binary env
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
- name: Test PyTorch binary
shell: bash
run: |
"${PYTORCH_ROOT}/.circleci/scripts/binary_windows_test.sh"
- name: Wait until all sessions have drained
shell: powershell
working-directory: pytorch
if: always()
timeout-minutes: 120
run: |
.github\scripts\wait_for_ssh_to_drain.ps1
- name: Kill active ssh sessions if still around (Useful if workflow was cancelled)
shell: powershell
working-directory: pytorch
if: always()
run: |
.github\scripts\kill_active_ssh_sessions.ps1
libtorch-cuda12_4-shared-with-deps-release-upload: # Uploading
if: ${{ github.repository_owner == 'pytorch' }}
permissions:
id-token: write
contents: read
needs: libtorch-cuda12_4-shared-with-deps-release-test
with:
PYTORCH_ROOT: ${{ github.workspace }}/pytorch
BUILDER_ROOT: ${{ github.workspace }}/builder
PACKAGE_TYPE: libtorch
# TODO: This is a legacy variable that we eventually want to get rid of in
# favor of GPU_ARCH_VERSION
DESIRED_CUDA: cu124
GPU_ARCH_VERSION: 12.4
GPU_ARCH_TYPE: cuda
LIBTORCH_CONFIG: release
LIBTORCH_VARIANT: shared-with-deps
# This is a dummy value for libtorch to work correctly with our batch scripts
# without this value pip does not get installed for some reason
DESIRED_PYTHON: "3.8"
build_name: libtorch-cuda12_4-shared-with-deps-release
secrets:
github-token: ${{ secrets.GITHUB_TOKEN }}
conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }}
conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }}
uses: ./.github/workflows/_binary-upload.yml

File diff suppressed because it is too large

View File

@ -1,40 +0,0 @@
name: inductor-micro-benchmark
on:
schedule:
- cron: 0 7 * * *
push:
tags:
- ciflow/inductor-micro-benchmark/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions: read-all
jobs:
linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image-name: pytorch-linux-focal-cuda12.1-cudnn8-py3-gcc9-inductor-benchmarks
cuda-arch-list: '8.0'
test-matrix: |
{ include: [
{ config: "inductor-micro-benchmark", shard: 1, num_shards: 1, runner: "linux.gcp.a100" },
]}
linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-test:
name: cuda12.1-py3.10-gcc9-sm80
uses: ./.github/workflows/_linux-test.yml
needs: linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build
with:
build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80
docker-image: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_1-py3_10-gcc9-inductor-micro-benchmark-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
timeout-minutes: 720

View File

@ -16,28 +16,28 @@ concurrency:
permissions: read-all
jobs:
linux-focal-rocm6_1-py3_8-inductor-build:
name: rocm6.1-py3.8-inductor
linux-focal-rocm6_0-py3_8-inductor-build:
name: rocm6.0-py3.8-inductor
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.2" },
]}
linux-focal-rocm6_1-py3_8-inductor-test:
linux-focal-rocm6_0-py3_8-inductor-test:
permissions:
id-token: write
contents: read
name: rocm6.1-py3.8-inductor
name: rocm6.0-py3.8-inductor
uses: ./.github/workflows/_rocm-test.yml
needs: linux-focal-rocm6_1-py3_8-inductor-build
needs: linux-focal-rocm6_0-py3_8-inductor-build
with:
build-environment: linux-focal-rocm6.1-py3.8
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-inductor-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.0-py3.8
docker-image: ${{ needs.linux-focal-rocm6_0-py3_8-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_0-py3_8-inductor-build.outputs.test-matrix }}
linux-focal-cuda12_1-py3_10-gcc9-inductor-build:
name: cuda12.1-py3.10-gcc9-sm86

View File

@ -1,39 +0,0 @@
name: linux-aarch64
on:
push:
tags:
- ciflow/linux-aarch64/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
cancel-in-progress: true
jobs:
linux-jammy-aarch64-py3_10-build:
name: linux-jammy-aarch64-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-jammy-aarch64-py3.10
docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
runner: linux.arm64.2xlarge
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 4, runner: "linux.arm64.2xlarge" },
{ config: "default", shard: 2, num_shards: 4, runner: "linux.arm64.2xlarge" },
{ config: "default", shard: 3, num_shards: 4, runner: "linux.arm64.2xlarge" },
{ config: "default", shard: 4, num_shards: 4, runner: "linux.arm64.2xlarge" },
]}
linux-jammy-aarch64-py3_10-test:
name: linux-jammy-aarch64-py3.10
uses: ./.github/workflows/_linux-test.yml
needs: linux-jammy-aarch64-py3_10-build
permissions:
id-token: write
contents: read
with:
build-environment: linux-jammy-aarch64-py3.10
docker-image: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.test-matrix }}

View File

@ -217,11 +217,11 @@ jobs:
docker-image: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-vulkan-focal-py3_11-clang10-build.outputs.test-matrix }}
linux-focal-rocm6_1-py3_8-build:
name: linux-focal-rocm6.1-py3.8
linux-focal-rocm6_0-py3_8-build:
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
@ -229,16 +229,16 @@ jobs:
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" },
]}
linux-focal-rocm6_1-py3_8-test:
linux-focal-rocm6_0-py3_8-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.1-py3.8
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_1-py3_8-build
- linux-focal-rocm6_0-py3_8-build
- target-determination
with:
build-environment: linux-focal-rocm6.1-py3.8
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.0-py3.8
docker-image: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.test-matrix }}

View File

@ -414,13 +414,13 @@ jobs:
{ config: "default", shard: 1, num_shards: 1 },
]}
linux-focal-rocm6_1-py3_8-build:
linux-focal-rocm6_0-py3_8-build:
# don't run build twice on main
if: github.event_name == 'pull_request'
name: linux-focal-rocm6.1-py3.8
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_linux-build-label.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |

View File

@ -25,11 +25,11 @@ jobs:
id-token: write
contents: read
linux-focal-rocm6_1-py3_8-build:
name: linux-focal-rocm6.1-py3.8
linux-focal-rocm6_0-py3_8-build:
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_linux-build-label.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
@ -42,16 +42,16 @@ jobs:
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
linux-focal-rocm6_1-py3_8-test:
linux-focal-rocm6_0-py3_8-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.1-py3.8
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_1-py3_8-build
- linux-focal-rocm6_0-py3_8-build
- target-determination
with:
build-environment: linux-focal-rocm6.1-py3.8
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.0-py3.8
docker-image: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.test-matrix }}

View File

@ -111,30 +111,30 @@ jobs:
docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }}
linux-focal-rocm6_1-py3_8-build:
name: linux-focal-rocm6.1-py3.8
linux-focal-rocm6_0-py3_8-build:
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
]}
linux-focal-rocm6_1-py3_8-test:
linux-focal-rocm6_0-py3_8-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.1-py3.8
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_1-py3_8-build
- linux-focal-rocm6_0-py3_8-build
- target-determination
with:
build-environment: linux-focal-rocm6.1-py3.8
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.0-py3.8
docker-image: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.test-matrix }}
linux-jammy-py3_10-clang15-asan-build:
name: linux-jammy-py3.10-clang15-asan
@ -144,9 +144,8 @@ jobs:
docker-image-name: pytorch-linux-jammy-py3-clang15-asan
test-matrix: |
{ include: [
{ config: "slow", shard: 1, num_shards: 3, runner: "linux.4xlarge" },
{ config: "slow", shard: 2, num_shards: 3, runner: "linux.4xlarge" },
{ config: "slow", shard: 3, num_shards: 3, runner: "linux.4xlarge" },
{ config: "slow", shard: 1, num_shards: 2, runner: "linux.4xlarge" },
{ config: "slow", shard: 2, num_shards: 2, runner: "linux.4xlarge" },
]}
sync-tag: asan-build

View File

@ -198,11 +198,11 @@ jobs:
{ config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge.nonephemeral" },
]}
linux-focal-rocm6_1-py3_8-build:
name: linux-focal-rocm6.1-py3.8
linux-focal-rocm6_0-py3_8-build:
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_linux-build-label.yml
with:
build-environment: linux-focal-rocm6.1-py3.8
build-environment: linux-focal-rocm6.0-py3.8
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
@ -210,17 +210,17 @@ jobs:
{ config: "default", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" },
]}
linux-focal-rocm6_1-py3_8-test:
linux-focal-rocm6_0-py3_8-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.1-py3.8
name: linux-focal-rocm6.0-py3.8
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_1-py3_8-build
- linux-focal-rocm6_0-py3_8-build
- target-determination
with:
build-environment: linux-focal-rocm6.1-py3.8
docker-image: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_1-py3_8-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.0-py3.8
docker-image: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_0-py3_8-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor"

View File

@ -36,20 +36,6 @@ jobs:
#
# Experimental ARC jobs
#
llm-td:
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
linux-jammy-py3_8-gcc11-build:
name: linux-jammy-py3.8-gcc11
@ -59,26 +45,16 @@ jobs:
docker-image-name: pytorch-linux-jammy-py3.8-gcc11
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 1, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 3, num_shards: 3, runner: "linux.2xlarge" },
{ config: "docs_test", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "jit_legacy", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "backwards_compat", shard: 1, num_shards: 1, runner: "linux.2xlarge" },
{ config: "distributed", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "distributed", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
]}
linux-jammy-py3_8-gcc11-test:
name: linux-jammy-py3.8-gcc11
uses: ./.github/workflows/_linux-test-rg.yml
needs:
- linux-jammy-py3_8-gcc11-build
- target-determination
with:
build-environment: linux-jammy-py3.8-gcc11
docker-image: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-py3_8-gcc11-build.outputs.test-matrix }}
linux-jammy-py3_8-gcc11-no-ops:
name: linux-jammy-py3.8-gcc11-no-ops
@ -110,21 +86,10 @@ jobs:
docker-image-name: pytorch-linux-focal-py3-clang10-onnx
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
]}
linux-focal-py3_8-clang10-onnx-test:
name: linux-focal-py3.8-clang10-onnx
uses: ./.github/workflows/_linux-test-rg.yml
needs:
- linux-focal-py3_8-clang10-onnx-build
- target-determination
with:
build-environment: linux-focal-py3.8-clang10-onnx
docker-image: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-py3_8-clang10-onnx-build.outputs.test-matrix }}
linux-jammy-py3_10-clang15-asan-build:
name: linux-jammy-py3.10-clang15-asan
uses: ./.github/workflows/_linux-build-rg.yml
@ -150,27 +115,16 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.8-clang10
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 1, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 3, num_shards: 3, runner: "linux.2xlarge" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 2, num_shards: 3, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 3, num_shards: 3, runner: "linux.2xlarge" },
]}
linux-focal-py3_8-clang10-test:
name: linux-focal-py3.8-clang10
uses: ./.github/workflows/_linux-test-rg.yml
needs:
- linux-focal-py3_8-clang10-build
- target-determination
with:
build-environment: linux-focal-py3.8-clang10
docker-image: ${{ needs.linux-focal-py3_8-clang10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-py3_8-clang10-build.outputs.test-matrix }}
linux-focal-py3_11-clang10-build:
name: linux-focal-py3.11-clang10
uses: ./.github/workflows/_linux-build-rg.yml
@ -179,27 +133,16 @@ jobs:
docker-image-name: pytorch-linux-focal-py3.11-clang10
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 2, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "dynamo", shard: 3, num_shards: 3, runner: "arc-lf-linux.2xlarge.avx512" },
{ config: "default", shard: 1, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 2, num_shards: 3, runner: "linux.2xlarge" },
{ config: "default", shard: 3, num_shards: 3, runner: "linux.2xlarge" },
{ config: "crossref", shard: 1, num_shards: 2, runner: "linux.2xlarge" },
{ config: "crossref", shard: 2, num_shards: 2, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 1, num_shards: 3, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 2, num_shards: 3, runner: "linux.2xlarge" },
{ config: "dynamo", shard: 3, num_shards: 3, runner: "linux.2xlarge" },
]}
linux-focal-py3_11-clang10-test:
name: linux-focal-py3.11-clang10
uses: ./.github/workflows/_linux-test-rg.yml
needs:
- linux-focal-py3_11-clang10-build
- target-determination
with:
build-environment: linux-focal-py3.11-clang10
docker-image: ${{ needs.linux-focal-py3_11-clang10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-py3_11-clang10-build.outputs.test-matrix }}
#
# End of Experimental ARC jobs
#
#

View File

@ -1051,12 +1051,21 @@ exclude_patterns = [
'test/quantization/fx/test_numeric_suite_fx.py',
'test/quantization/fx/test_quantize_fx.py',
'test/quantization/fx/test_subgraph_rewriter.py',
'test/test_custom_op_testing.py',
'test/test_dataloader.py',
'test/test_datapipe.py',
'test/test_decomp.py',
'test/test_deploy.py',
'test/test_determination.py',
'test/test_dlpack.py',
'test/test_dynamic_shapes.py',
'test/test_expanded_weights.py',
'test/test_fake_tensor.py',
'test/test_flop_counter.py',
'test/test_function_schema.py',
'test/test_functional_autograd_benchmark.py',
'test/test_functional_optim.py',
'test/test_functionalization.py',
'test/test_functionalization_of_rng_ops.py',
'test/test_futures.py',
'test/test_fx.py',
@ -1065,6 +1074,7 @@ exclude_patterns = [
'test/test_fx_reinplace_pass.py',
'test/test_hub.py',
'test/test_import_stats.py',
'test/test_indexing.py',
'test/test_itt.py',
'test/test_jit.py',
'test/test_jit_autocast.py',
@ -1150,6 +1160,7 @@ exclude_patterns = [
'test/test_type_promotion.py',
'test/test_unary_ufuncs.py',
'test/test_utils.py',
'test/test_view_ops.py',
'test/test_vulkan.py',
'test/test_xnnpack_integration.py',
'test/torch_np/numpy_test/**/*.py',

View File

@ -446,6 +446,7 @@ cu_library(
# caffe2
CAFFE2_COPTS = COMMON_COPTS + [
"-Dcaffe2_EXPORTS",
"-DCAFFE2_USE_GLOO",
"-DCAFFE2_USE_CUDNN",
"-DCAFFE2_BUILD_MAIN_LIB",
"-fvisibility-inlines-hidden",
@ -453,6 +454,22 @@ CAFFE2_COPTS = COMMON_COPTS + [
"-fno-trapping-math",
]
filegroup(
name = "caffe2_contrib_srcs",
srcs = [
"caffe2/contrib/aten/aten_op.cc",
"caffe2/contrib/gloo/allgather_ops.cc",
"caffe2/contrib/gloo/allreduce_ops.cc",
"caffe2/contrib/gloo/barrier_ops.cc",
"caffe2/contrib/gloo/broadcast_ops.cc",
"caffe2/contrib/gloo/common.cc",
"caffe2/contrib/gloo/common_world_ops.cc",
"caffe2/contrib/gloo/context.cc",
"caffe2/contrib/gloo/reduce_scatter_ops.cc",
"caffe2/contrib/gloo/store_handler.cc",
],
)
filegroup(
name = "caffe2_core_srcs",
srcs = [
@ -503,6 +520,16 @@ filegroup(
],
)
filegroup(
name = "caffe2_distributed_srcs",
srcs = [
"caffe2/distributed/file_store_handler.cc",
"caffe2/distributed/file_store_handler_op.cc",
"caffe2/distributed/store_handler.cc",
"caffe2/distributed/store_ops.cc",
],
)
filegroup(
name = "caffe2_ideep_srcs",
srcs = [
@ -997,10 +1024,16 @@ filegroup(
filegroup(
name = "caffe2_cuda_cpp_srcs",
srcs = [
"caffe2/contrib/aten/aten_op_gpu.cc",
"caffe2/contrib/gloo/allreduce_ops_gpu.cc",
"caffe2/contrib/gloo/broadcast_ops_gpu.cc",
"caffe2/contrib/gloo/common_world_ops_gpu.cc",
"caffe2/core/blob_serialization_gpu.cc",
"caffe2/core/common_cudnn.cc",
"caffe2/core/common_gpu.cc",
"caffe2/core/event_gpu.cc",
"caffe2/db/create_db_op_gpu.cc",
"caffe2/distributed/file_store_handler_op_gpu.cc",
"caffe2/operators/communicator_op_gpu.cc",
"caffe2/operators/concat_split_op_gpu.cc",
"caffe2/operators/conv_op_cache_cudnn.cc",
@ -1238,10 +1271,35 @@ cc_library(
],
)
py_binary(
name = "gen_op",
srcs = ["caffe2/contrib/aten/gen_op.py"],
deps = ["//torchgen"],
)
genrule(
name = "generated_caffe2_aten_op_headers",
srcs = [
"caffe2/contrib/aten/aten_op_template.h",
"aten/src/ATen/Declarations.yaml",
],
outs = ["caffe2/caffe2/contrib/aten/gen_aten_op.h"],
cmd = """
$(location :gen_op) \
--output_prefix gen_ \
--install_dir $(@D) \
--aten_root `dirname $(location aten/src/ATen/Declarations.yaml)`/../.. \
--template_dir `dirname $(location caffe2/contrib/aten/aten_op_template.h)` \
--yaml_dir `dirname $(location aten/src/ATen/Declarations.yaml)`""",
tools = [":gen_op"],
)
cc_library(
name = "caffe2_headers",
hdrs = glob(
[
"caffe2/contrib/aten/*.h",
"caffe2/contrib/gloo/*.h",
"caffe2/core/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Converters/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Generated/*.h",
@ -1250,6 +1308,8 @@ cc_library(
"caffe2/core/nomnigraph/include/nomnigraph/Support/*.h",
"caffe2/core/nomnigraph/include/nomnigraph/Transformations/*.h",
"caffe2/core/nomnigraph/tests/*.h",
"caffe2/db/*.h",
"caffe2/distributed/*.h",
"caffe2/ideep/*.h",
"caffe2/ideep/operators/*.h",
"caffe2/ideep/operators/quantization/*.h",
@ -1278,9 +1338,10 @@ cc_library(
) + if_cuda(glob([
"caffe2/**/*.cuh",
"caffe2/image/*.h",
])),
])) + [":generated_caffe2_aten_op_headers"],
copts = CAFFE2_COPTS,
includes = [
"caffe2/contrib/aten",
"caffe2/core/nomnigraph/include",
],
visibility = ["//visibility:public"],
@ -1321,8 +1382,12 @@ cc_library(
cc_library(
name = "caffe2",
srcs = [
"caffe2/db/create_db_op.cc",
"caffe2/db/protodb.cc",
"caffe2/share/contrib/depthwise/depthwise3x3_conv_op.cc",
":caffe2_contrib_srcs",
":caffe2_core_srcs",
":caffe2_distributed_srcs",
":caffe2_ideep_srcs",
":caffe2_onnx_srcs",
":caffe2_operators_srcs",
@ -1354,6 +1419,7 @@ cc_library(
"@fbgemm//:fbgemm_src_headers",
"@fmt",
"@foxi",
"@gloo",
"@onnx",
] + if_cuda(
[
@ -1401,6 +1467,7 @@ cu_library(
"@cuda//:curand",
"@cudnn",
"@eigen",
"@gloo",
"@tensorpipe//:tensorpipe_cuda",
],
alwayslink = True,

View File

@ -56,7 +56,7 @@ endif()
# This define is needed to preserve behavior given anticipated changes to cccl/thrust
# https://nvidia.github.io/libcudacxx/standard_api/numerics_library/complex.html
string(APPEND CMAKE_CUDA_FLAGS " -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS")
string(APPEND CMAKE_CUDA_FLAGS "-DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS")
if(LINUX)
include(cmake/CheckAbi.cmake)
@ -228,9 +228,12 @@ option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
option(USE_FFMPEG "Use ffmpeg" OFF)
option(USE_GFLAGS "Use GFLAGS" OFF)
option(USE_GLOG "Use GLOG" OFF)
option(USE_LEVELDB "Use LEVELDB" OFF)
option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
option(USE_LMDB "Use LMDB" OFF)
option(USE_MAGMA "Use MAGMA" ON)
option(USE_METAL "Use Metal for Caffe2 iOS build" ON)
option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
@ -261,12 +264,15 @@ cmake_dependent_option(
option(USE_NUMPY "Use NumPy" ON)
option(USE_OBSERVERS "Use observers module." OFF)
option(USE_OPENCL "Use OpenCL" OFF)
option(USE_OPENCV "Use OpenCV" OFF)
option(USE_OPENMP "Use OpenMP for parallel code" ON)
option(USE_PRECOMPILED_HEADERS "Use pre-compiled headers to accelerate build." OFF)
option(USE_PROF "Use profiling" OFF)
option(USE_QNNPACK "Use QNNPACK (quantized 8-bit operators)" ON)
option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON)
option(USE_REDIS "Use Redis" OFF)
option(USE_ROCKSDB "Use RocksDB" OFF)
option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
option(USE_SYSTEM_EIGEN_INSTALL
"Use system Eigen instead of the one under third_party" OFF)
@ -288,6 +294,7 @@ option(USE_VULKAN_FP16_INFERENCE "Vulkan - Use fp16 inference" OFF)
option(USE_VULKAN_RELAXED_PRECISION "Vulkan - Use relaxed precision math in the kernels (mediump)" OFF)
# option USE_XNNPACK: try to enable xnnpack by default.
option(USE_XNNPACK "Use XNNPACK" ON)
option(USE_ZMQ "Use ZMQ" OFF)
option(USE_ZSTD "Use ZSTD" OFF)
option(USE_ROCM_KERNEL_ASSERT "Use Kernel Assert for ROCm" OFF)
# Ensure that an ITT build is the default for x86 CPUs

View File

@ -116,7 +116,7 @@ torch/profiler/ @aaronenyeshi
test/functorch/test_aotdispatch.py @ezyang @Chillee
# Dataloader
torch/utils/data/ @andrewkho @gokulavasan
torch/utils/data/ @ejguan
# hipify
torch/utils/hipify/ @jeffdaily @jithunnair-amd
@ -144,14 +144,3 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd
/torch/csrc/Storage* @mikaylagawarecki
# subscribing for PyTorchFileWriter/PyTorchFileReader changes
/torch/csrc/jit/python/init.cpp @mikaylagawarecki
# CUDA and CUDA math libraries
aten/src/ATen/cuda/ @eqy
aten/src/ATen/cudnn/ @eqy
aten/src/ATen/native/cuda/ @eqy
aten/src/ATen/native/cudnn/ @eqy
c10/cuda @eqy
torch/cuda/ @eqy
torch/csrc/cuda/ @eqy
torch/backends/cuda/ @eqy
torch/backends/cudnn/ @eqy

View File

@ -1,10 +1,12 @@
# syntax=docker/dockerfile:1
# NOTE: Building this image requires docker version >= 23.0.
# syntax = docker/dockerfile:experimental
#
# For reference:
# - https://docs.docker.com/build/dockerfile/frontend/#stable-channel
# NOTE: To build this you will need a docker version > 18.06 with
# experimental enabled and DOCKER_BUILDKIT=1
#
# If you do not use buildkit you are not going to have a good time
#
# For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements/
ARG BASE_IMAGE=ubuntu:22.04
ARG PYTHON_VERSION=3.11

View File

@ -268,12 +268,6 @@ at::BlasBackend Context::blasPreferredBackend() const {
}
void Context::setBlasPreferredBackend(at::BlasBackend b) {
#ifdef _MSC_VER
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
"It is not supported on Windows."
);
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
if (b != at::BlasBackend::Cublas) {
@ -284,7 +278,6 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
);
}
blas_preferred_backend = b;
#endif
}
bool Context::allowFP16ReductionCuBLAS() const {
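The check removed above is the native side of `torch.backends.cuda.preferred_blas_library` (the name quoted in the warning text). A minimal sketch of how it surfaces in Python, assuming a CUDA build of PyTorch; requesting cuBLASLt is rejected when the build does not include it:

import torch

# Query the currently preferred BLAS backend, then request cuBLASLt.
# On builds compiled without cuBLASLt support this request errors out.
print(torch.backends.cuda.preferred_blas_library())
torch.backends.cuda.preferred_blas_library("cublaslt")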

View File

@ -158,6 +158,159 @@ namespace {
Explicit registration for out-of-place ops
*****************************************/
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)
#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)
#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)
#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)
#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)
TORCH_LIBRARY_IMPL(_, Autocast, m) {
m.fallback(torch::CppFunction::makeFallthrough());
}
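The three macro lists above are what the autocast registrations consult: AT_FORALL_LOWER_PRECISION_FP ops run in the autocast dtype, AT_FORALL_FP32 ops stay in float32, and AT_FORALL_PROMOTE ops promote to the widest input type. A minimal sketch of the observable behaviour from Python (assumes a CUDA device; `mm` sits on the lower-precision list, `acos` on the fp32 list):

import torch

a = torch.randn(8, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    print(torch.mm(a, a).dtype)   # torch.float16: op is on the lower-precision list
    print(torch.acos(a).dtype)    # torch.float32: op is on the fp32 list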

View File

@ -728,7 +728,7 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE
// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1
#define KERNEL_PRIVATEUSEONE(...) \
#define KERNEL_PRIVATEUSEONE(OP, ...) \
KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__)
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \
@ -744,158 +744,3 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
REGISTER_SIGNATURE, \
REDISPATCH_SIGNATURE, \
POLICY)
// Op lists for different policies.
// To make sure other backends can reuse the policy op list.
#define AT_FORALL_LOWER_PRECISION_FP(_) \
_(_convolution, deprecated) \
_(_convolution) \
_(conv1d) \
_(conv2d) \
_(conv3d) \
_(conv_tbc) \
_(conv_transpose1d) \
_(conv_transpose2d, input) \
_(conv_transpose3d, input) \
_(convolution) \
_(prelu) \
_(addmm) \
_(addmv) \
_(addr) \
_(matmul) \
_(einsum) \
_(mm) \
_(mv) \
_(linalg_vecdot) \
_(linear) \
_(addbmm) \
_(baddbmm) \
_(bmm) \
_(chain_matmul) \
_(linalg_multi_dot) \
_(_thnn_fused_lstm_cell) \
_(_thnn_fused_gru_cell) \
_(lstm_cell) \
_(gru_cell) \
_(rnn_tanh_cell) \
_(rnn_relu_cell) \
_(_scaled_dot_product_flash_attention) \
_(scaled_dot_product_attention)
#define AT_FORALL_FP32(_) \
_(acos) \
_(asin) \
_(cosh) \
_(erfinv) \
_(exp) \
_(expm1) \
_(log) \
_(log10) \
_(log2) \
_(log1p) \
_(reciprocal) \
_(rsqrt) \
_(sinh) \
_(tan) \
_(pow, Tensor_Scalar) \
_(pow, Tensor_Tensor) \
_(pow, Scalar) \
_(softplus) \
_(layer_norm) \
_(native_layer_norm) \
_(group_norm) \
_(frobenius_norm, dim) \
_(nuclear_norm) \
_(nuclear_norm, dim) \
_(cosine_similarity) \
_(poisson_nll_loss) \
_(cosine_embedding_loss) \
_(nll_loss) \
_(nll_loss2d) \
_(hinge_embedding_loss) \
_(kl_div) \
_(l1_loss) \
_(smooth_l1_loss) \
_(huber_loss) \
_(mse_loss) \
_(margin_ranking_loss) \
_(multilabel_margin_loss) \
_(soft_margin_loss) \
_(triplet_margin_loss) \
_(multi_margin_loss) \
_(binary_cross_entropy_with_logits) \
_(dist) \
_(pdist) \
_(cdist) \
_(renorm) \
_(logsumexp) \
_(upsample_nearest1d) \
_(_upsample_nearest_exact1d) \
_(upsample_nearest2d) \
_(_upsample_nearest_exact2d) \
_(upsample_nearest3d) \
_(_upsample_nearest_exact3d) \
_(upsample_linear1d) \
_(upsample_bilinear2d) \
_(_upsample_bilinear2d_aa) \
_(upsample_trilinear3d) \
_(upsample_bicubic2d) \
_(_upsample_bicubic2d_aa)
#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \
_(prod) \
_(prod, dim_int) \
_(prod, dim_Dimname) \
_(softmax, int) \
_(softmax, Dimname) \
_(log_softmax, int) \
_(log_softmax, Dimname) \
_(cumprod) \
_(cumprod, dimname) \
_(cumsum) \
_(cumsum, dimname) \
_(linalg_vector_norm) \
_(linalg_matrix_norm) \
_(linalg_matrix_norm, str_ord) \
_(sum) \
_(sum, dim_IntList) \
_(sum, dim_DimnameList)
#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \
_(ADD_NS(norm), \
"norm.Scalar", \
Tensor(const Tensor&, const Scalar&), \
Tensor(const Tensor&, const c10::optional<Scalar>&, ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, IntArrayRef, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
IntArrayRef, \
bool, \
ScalarType), \
fp32_append_dtype) \
_(ADD_NS(norm), \
"norm.names_ScalarOpt_dim", \
Tensor(const Tensor&, const c10::optional<Scalar>&, DimnameList, bool), \
Tensor( \
const Tensor&, \
const c10::optional<Scalar>&, \
DimnameList, \
bool, \
ScalarType), \
fp32_append_dtype)
#define AT_FORALL_PROMOTE(_) \
_(addcdiv) \
_(addcmul) \
_(atan2) \
_(bilinear) \
_(cross) \
_(dot) \
_(grid_sampler) \
_(index_put) \
_(tensordot) \
_(scatter_add)

View File

@ -9,7 +9,7 @@ namespace c10 {
// const reference (const T&); taking T by non-const reference
// will result in an error like:
//
// error: no type named 'type' in 'class std::invoke_result<foobar::__lambda, T>'
// error: no type named 'type' in 'class std::result_of<foobar::__lambda(T)>'
//
// No explicit template parameters are required.

View File

@ -227,7 +227,6 @@ namespace c10 {
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
_(aten, get_autocast_dtype) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \

View File

@ -1034,9 +1034,11 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
*/
template <typename T>
void addCallback(T callback, bool uses_future = true) {
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::is_invocable_r<void, T, Future&>::value,
"The callback must have signature void(Future&)");
#endif
std::unique_lock<std::mutex> lock(mutex_);
if (completed()) {
@ -1055,13 +1057,14 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
template <typename T>
c10::intrusive_ptr<Future> then(T callback, TypePtr type) {
using IValueWithStorages = std::tuple<IValue, std::vector<WeakStorage>>;
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::disjunction<
std::is_invocable_r<IValue, T, Future&>,
std::is_invocable_r<IValueWithStorages, T, Future&>>::value,
"The callback must have signature IValue(Future&) or "
"std::tuple<IValue, std::vector<Storage>>(Future&)");
#endif
auto childFut = createInstance(::std::move(type));
addCallback([childFut,
cb = std::move(callback)](Future& parentFut) mutable {
@ -1081,10 +1084,11 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
template <typename T>
c10::intrusive_ptr<Future> thenAsync(T callback, TypePtr type) {
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::is_invocable_r<c10::intrusive_ptr<Future>, T, Future&>::value,
"The callback must have signature c10::intrusive_ptr<Future>(Future&)");
#endif
auto childFut = createInstance(std::move(type));
addCallback(
[childFut, cb = std::move(callback)](Future& parentFut) mutable {
@ -1161,9 +1165,11 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target {
// synchronize them with the value, and so on (if needed).
template<typename T>
void invokeCallback(T callback, bool uses_future) {
#if __cpp_lib_is_invocable >= 201703
static_assert(
std::is_invocable_r<void, T, Future&>::value,
"The callback must have signature void(Future&)");
#endif
// The synchronization performed below shouldn't be needed when the future
// is not used by the callback.
@ -2315,7 +2321,8 @@ IValue::IValue(c10::intrusive_ptr<T> custom_class) : tag(Tag::Object) {
} catch (const c10::Error&) {
throw c10::Error(
"Trying to instantiate a class that isn't a registered custom class: " +
std::string(c10::util::get_fully_qualified_type_name<T>()));
std::string(c10::util::get_fully_qualified_type_name<T>()),
"");
}
}();
auto ivalue_obj = c10::ivalue::Object::create(std::move(classType), /* numSlots */1);
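The `static_assert`s being conditionally compiled above only enforce the callback signatures (`void(Future&)`, `IValue(Future&)`, and so on); they do not change behaviour. The same contract is visible through the Python binding `torch.futures.Future` — a small illustrative sketch, not tied to this diff:

import torch

fut = torch.futures.Future()

def cb(f):
    # The callback receives the completed parent future; its return value
    # becomes the value of the child future created by then().
    return f.value() + 1

child = fut.then(cb)
fut.set_result(41)
print(child.wait())  # 42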

View File

@ -126,44 +126,32 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
}
};
template <typename dst_t, typename src_t>
struct VecConvert<
dst_t,
1,
src_t,
1,
typename std::enable_if_t<
(is_reduced_floating_point_v<dst_t> && is_8bit_integer_v<src_t>) ||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 1>& src) {
dst_t,
1,
float,
1,
typename std::enable_if_t<
std::is_same_v<dst_t, unsigned char> || std::is_same_v<dst_t, signed char>,
void>> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};
template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 1> apply(const VectorizedN<src_t, 1>& src) {
float,
1,
src_t,
1,
typename std::enable_if_t<
std::is_same_v<src_t, unsigned char> || std::is_same_v<src_t, signed char>,
void>> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};
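The specialization removed above routes 8-bit-integer to reduced-precision-float conversions (and back) through a float32 intermediate. Purely as an illustration of that two-step path at the tensor level (not the vectorized kernel itself):

import torch

src = torch.tensor([0, 1, 127, 255], dtype=torch.uint8)
two_step = src.to(torch.float32).to(torch.bfloat16)  # uint8 -> fp32 -> bf16
direct = src.to(torch.bfloat16)                       # single cast
print(torch.equal(two_step, direct))  # True here; every uint8 value is exact in bf16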

View File

@ -13,6 +13,8 @@
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/complex.h>
#define SLEEF_MEMORY_WORKAROUND
namespace at {
namespace vec {
@ -1146,20 +1148,32 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented<T>()>> {
}
Vectorized<T> sin() const {
#ifndef SLEEF_MEMORY_WORKAROUND
return mapSleef(Sleef_sinf4_u10, Sleef_sind2_u10);
#else
return mapOrdinary(std::sin);
#endif
}
Vectorized<T> sinh() const {
return mapSleef(Sleef_sinhf4_u10, Sleef_sinhd2_u10);
}
Vectorized<T> cos() const {
#ifndef SLEEF_MEMORY_WORKAROUND
return mapSleef(Sleef_cosf4_u10, Sleef_cosd2_u10);
#else
return mapOrdinary(std::cos);
#endif
}
Vectorized<T> cosh() const {
return mapSleef(Sleef_coshf4_u10, Sleef_coshd2_u10);
}
Vectorized<T> tan() const {
#ifndef SLEEF_MEMORY_WORKAROUND
return mapSleef(Sleef_tanf4_u10, Sleef_tand2_u10);
#else
return mapOrdinary(std::tan);
#endif
}
Vectorized<T> tanh() const {
return mapSleef(Sleef_tanhf4_u10, Sleef_tanhd2_u10);

View File

@ -117,44 +117,32 @@ struct VecConvert<int32_t, 1, uint8_t, 1> {
}
};
template <typename dst_t, typename src_t>
struct VecConvert<
dst_t,
1,
src_t,
1,
typename std::enable_if_t<
(is_reduced_floating_point_v<dst_t> && is_8bit_integer_v<src_t>) ||
(is_reduced_floating_point_v<src_t> && is_8bit_integer_v<dst_t>),
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<src_t, 1>& src) {
VectorizedN<float, 1> tmp_fp32 = VecConvert<float, 1, src_t, 1>::apply(src);
return VecConvert<dst_t, 1, float, 1>::apply(tmp_fp32);
}
};
template <typename dst_t>
struct VecConvert<
dst_t,
1,
float,
1,
typename std::enable_if_t<is_8bit_integer_v<dst_t>,
void>> {
static inline VectorizedN<dst_t, 1> apply(const VectorizedN<float, 1>& src) {
dst_t,
1,
float,
1,
typename std::enable_if_t<
std::is_same_v<dst_t, unsigned char> || std::is_same_v<dst_t, signed char>,
void>> {
static inline VectorizedN<dst_t, 1> apply(
const VectorizedN<float, 1>& src) {
return convert_float_to_int8<dst_t>(src[0]);
}
};
template <typename src_t>
struct VecConvert<
float,
1,
src_t,
1,
typename std::enable_if_t<is_8bit_integer_v<src_t>,
void>> {
static inline VectorizedN<float, 1> apply(const VectorizedN<src_t, 1>& src) {
float,
1,
src_t,
1,
typename std::enable_if_t<
std::is_same_v<src_t, unsigned char> || std::is_same_v<src_t, signed char>,
void>> {
static inline VectorizedN<float, 1> apply(
const VectorizedN<src_t, 1>& src) {
return convert_int8_to_float<src_t>(src[0]);
}
};

View File

@ -90,16 +90,6 @@ struct is_reduced_floating_point:
template <typename T>
constexpr bool is_reduced_floating_point_v = is_reduced_floating_point<T>::value;
template <typename T>
struct is_8bit_integer:
std::integral_constant<bool,
std::is_same_v<T, unsigned char> ||
std::is_same_v<T, signed char>> {
};
template <typename T>
constexpr bool is_8bit_integer_v = is_8bit_integer<T>::value;
template<size_t n> struct int_of_size;
#define DEFINE_INT_OF_SIZE(int_t) \

View File

@ -15,14 +15,6 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/util/StringUtil.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/allclose.h>
#include <ATen/ops/from_blob.h>
#endif
namespace at::cuda::tunable {
enum class BlasOp {
@ -41,39 +33,6 @@ inline std::string BlasOpToString(BlasOp op) {
return "N";
}
namespace detail {
static bool NumericalCheck(ScalarType dtype, void* c, void* other_c, int64_t size) {
auto options = at::TensorOptions().dtype(dtype).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {size}, options);
at::Tensor oth = at::from_blob(other_c, {size}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return false;
}
else {
TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return true;
}
}
template <typename T>
struct GemmParams : OpParams {
std::string Signature() const override {
@ -98,8 +57,32 @@ struct GemmParams : OpParams {
}
TuningStatus NumericalCheck(GemmParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL;
auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {m*n}, options);
at::Tensor oth = at::from_blob(other->c, {m*n}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return FAIL;
}
else {
TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return OK;
}
char transa;
@ -141,8 +124,32 @@ struct GemmStridedBatchedParams : OpParams {
}
TuningStatus NumericalCheck(GemmStridedBatchedParams<T> *other) {
auto c_dtype = c10::CppTypeToScalarType<T>::value;
return detail::NumericalCheck(c_dtype, c, other->c, batch*stride_c) ? OK : FAIL;
auto options = at::TensorOptions().dtype(c10::CppTypeToScalarType<T>::value).device(at::kCUDA);
// comparison done as 1D tensor
at::Tensor ref = at::from_blob(c, {batch*stride_c}, options);
at::Tensor oth = at::from_blob(other->c, {batch*stride_c}, options);
at::Tensor ref_float = ref.to(at::kFloat);
at::Tensor oth_float = oth.to(at::kFloat);
std::vector<double> atols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
std::vector<double> rtols{1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
double last_succeed_atol = 1;
double last_succeed_rtol = 1;
for (auto& atol : atols) {
for (auto& rtol : rtols) {
if (at::allclose(ref_float, oth_float, rtol, atol)) {
last_succeed_atol = atol;
last_succeed_rtol = rtol;
}
}
}
if (last_succeed_atol == 1) {
return FAIL;
}
else {
TUNABLE_LOG("├──verify numerics: atol=", last_succeed_atol, ", rtol=", last_succeed_rtol);
}
return OK;
}
char transa;
@ -164,54 +171,4 @@ struct GemmStridedBatchedParams : OpParams {
int64_t batch;
};
template <typename T>
struct ScaledGemmParams : OpParams {
std::string Signature() const override {
return c10::str(transa, transb, "_", m, "_", n, "_", k);
}
ScaledGemmParams* DeepCopy() const {
ScaledGemmParams* copy = new ScaledGemmParams;
*copy = *this;
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
size_t c_size = m * n * sizeof(T);
copy->c = c10::cuda::CUDACachingAllocator::raw_alloc(c_size);
AT_CUDA_CHECK(c10::cuda::CUDACachingAllocator::memcpyAsync(
copy->c, device, c, device, c_size, getCurrentCUDAStream(device), true));
return copy;
}
// only call on object returned by DeepCopy
void Delete() {
c10::cuda::CUDACachingAllocator::raw_delete(c);
}
TuningStatus NumericalCheck(ScaledGemmParams<T> *other) {
return detail::NumericalCheck(c_dtype, c, other->c, m*n) ? OK : FAIL;
}
char transa;
char transb;
int64_t m;
int64_t n;
int64_t k;
const void* a;
const void* a_scale_ptr;
int64_t lda;
ScalarType a_dtype;
const void* b;
const void* b_scale_ptr;
int64_t ldb;
ScalarType b_dtype;
const void* bias_ptr;
ScalarType bias_dtype;
void* c;
const void* c_scale_ptr;
int64_t ldc;
ScalarType c_dtype;
void* amax_ptr;
bool use_fast_accum;
};
} // namespace at::cuda::tunable
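For reference, the NumericalCheck tolerance sweep in this file tries a grid of (atol, rtol) pairs from loose to tight and reports the last pair for which at::allclose still holds, i.e. |ref - oth| <= atol + rtol * |oth| elementwise. A standalone sketch of the same idea (illustration only; plain std::vector stands in for the CUDA tensors):

#include <cstddef>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference allclose: the same elementwise criterion at::allclose uses.
static bool allclose_ref(const std::vector<float>& a, const std::vector<float>& b,
                         double rtol, double atol) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    if (std::fabs((double)a[i] - (double)b[i]) > atol + rtol * std::fabs((double)b[i])) {
      return false;
    }
  }
  return true;
}

int main() {
  std::vector<float> ref{1.0f, 2.0f, 3.0f};
  std::vector<float> oth{1.0f, 2.0001f, 3.0f};
  const double tols[] = {1e-1, 1e-2, 1e-3, 1e-4, 1e-5};
  double last_succeed_atol = 1;
  double last_succeed_rtol = 1;
  for (double atol : tols) {
    for (double rtol : tols) {
      if (allclose_ref(ref, oth, rtol, atol)) {
        last_succeed_atol = atol;  // record the last passing pair as the sweep tightens
        last_succeed_rtol = rtol;
      }
    }
  }
  if (last_succeed_atol == 1) {
    std::puts("numerical check FAILED");
  } else {
    std::printf("verify numerics: atol=%g, rtol=%g\n", last_succeed_atol, last_succeed_rtol);
  }
  return 0;
}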

View File

@ -4,7 +4,6 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <ATen/cuda/tunable/GemmCommon.h>
#include <c10/cuda/CUDACachingAllocator.h>
@ -68,16 +67,6 @@ constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLASLT_R_64F;
}
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<c10::Float8_e4m3fnuz>() {
return HIPBLASLT_R_8F_E4M3;
}
template <>
constexpr hipblasltDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
return HIPBLASLT_R_8F_E5M3;
}
#define DATA_TYPE_R_32 HIPBLASLT_R_32F
#else
@ -105,16 +94,6 @@ constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLAS_R_64F;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e4m3fnuz>() {
return HIP_R_8F_E4M3_FNUZ;
}
template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
return HIP_R_8F_E5M2_FNUZ;
}
#ifdef HIPBLAS_V2
#define DATA_TYPE_R_32 HIP_R_32F
#else
@ -123,8 +102,8 @@ constexpr hipblasDatatype_t HipBlasDataTypeFor<c10::Float8_e5m2fnuz>() {
#endif
template <typename T>
int GetBatchFromParams(const GemmParams<T>* params) {
template <typename T, typename ParamsT>
int GetBatchFromParams(const ParamsT* params) {
return 1;
}
@ -133,13 +112,8 @@ int GetBatchFromParams(const GemmStridedBatchedParams<T>* params) {
return params->batch;
}
template <typename T>
int GetBatchFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideAFromParams(const GemmParams<T>* params) {
template <typename T, typename ParamsT>
int GetStrideAFromParams(const ParamsT* params) {
return 1;
}
@ -148,13 +122,8 @@ int GetStrideAFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_a;
}
template <typename T>
int GetStrideAFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideBFromParams(const GemmParams<T>* params) {
template <typename T, typename ParamsT>
int GetStrideBFromParams(const ParamsT* params) {
return 1;
}
@ -163,13 +132,8 @@ int GetStrideBFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_b;
}
template <typename T>
int GetStrideBFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
int GetStrideCFromParams(const GemmParams<T>* params) {
template <typename T, typename ParamsT>
int GetStrideCFromParams(const ParamsT* params) {
return 1;
}
@ -178,116 +142,6 @@ int GetStrideCFromParams(const GemmStridedBatchedParams<T>* params) {
return params->stride_c;
}
template <typename T>
int GetStrideCFromParams(const ScaledGemmParams<T>* params) {
return 1;
}
template <typename T>
float GetAlphaFromParams(const GemmParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->alpha;
}
template <typename T>
float GetAlphaFromParams(const ScaledGemmParams<T>* params) {
return 1.0;
}
template <typename T>
float GetBetaFromParams(const GemmParams<T>* params) {
return params->beta;
}
template <typename T>
float GetBetaFromParams(const GemmStridedBatchedParams<T>* params) {
return params->beta;
}
template <typename T>
float GetBetaFromParams(const ScaledGemmParams<T>* params) {
return 0.0;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetAScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->a_scale_ptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->b_scale_ptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetDScalePointerFromParams(const ScaledGemmParams<T>* params) {
return params->c_scale_ptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const GemmStridedBatchedParams<T>* params) {
return nullptr;
}
template <typename T>
const void* GetBiasPointerFromParams(const ScaledGemmParams<T>* params) {
return params->bias_ptr;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmParams<T>* params) {
return HIP_R_32F;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const GemmStridedBatchedParams<T>* params) {
return HIP_R_32F;
}
template <typename T>
hipDataType GetBiasTypeFromParams(const ScaledGemmParams<T>* params) {
return at::cuda::ScalarTypeToCudaDataType(params->bias_dtype);
}
static hipblasOperation_t _hipblasOpFromChar(char op) {
switch (op) {
case 'n':
@ -344,48 +198,7 @@ static size_t GetHipblasltWorkspaceSize() {
return workspace_size;
}
template <typename T, cublasStatus_t (*destructor)(T*)>
struct HipBlasLtDeleter {
void operator()(T* x) {
if (x != nullptr) {
TORCH_CUDABLAS_CHECK(destructor(x));
}
}
};
template <typename T, hipblasStatus_t (*destructor)(T*)>
class HipBlasLtDescriptor {
public:
T* descriptor() const {
return descriptor_.get();
}
T* descriptor() {
return descriptor_.get();
}
protected:
std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
};
class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
hipblasLtMatmulDescOpaque_t,
&hipblasLtMatmulDescDestroy> {
public:
HipBlasLtMatmulDescriptor(
hipblasComputeType_t compute_type,
hipDataType scale_type) {
hipblasLtMatmulDesc_t raw_descriptor = nullptr;
TORCH_HIPBLASLT_CHECK(
hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
descriptor_.reset(raw_descriptor);
}
template <typename T>
inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
TORCH_HIPBLASLT_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(T)));
}
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
class HipblasltGemmOp : public Callable<ParamsT> {
public:
HipblasltGemmOp(hipblasLtMatmulAlgo_t algo) : algo_{algo} {}
@ -393,38 +206,37 @@ class HipblasltGemmOp : public Callable<ParamsT> {
TuningStatus Call(const ParamsT* params) override {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipBlasDataTypeFor<AT>();
auto b_datatype = HipBlasDataTypeFor<BT>();
auto in_out_datatype = HipBlasDataTypeFor<CT>();
auto in_out_datatype = HipBlasDataTypeFor<T>();
auto opa = _hipblasOpFromChar(params->transa);
auto opb = _hipblasOpFromChar(params->transb);
TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen");
float alpha = GetAlphaFromParams<CT>(params);
float beta = GetBetaFromParams<CT>(params);
float alpha = static_cast<float>(params->alpha);
float beta = static_cast<float>(params->beta);
hipblasLtMatrixLayout_t mat_a, mat_b, mat_c;
hipblasLtMatmulDesc_t matmul;
if (opa == HIPBLAS_OP_N) {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->m, params->k, params->lda));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->m, params->k, params->lda));
}
else {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, a_datatype, params->k, params->m, params->lda));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, params->k, params->m, params->lda));
}
if (opb == HIPBLAS_OP_N) {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->k, params->n, params->ldb));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->k, params->n, params->ldb));
}
else {
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, b_datatype, params->n, params->k, params->ldb));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, params->n, params->k, params->ldb));
}
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, params->m, params->n, params->ldc));
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescCreate(&matmul, COMPUTE_TYPE_32, DATA_TYPE_R_32));
// specific to batched gemm
int batch = GetBatchFromParams<CT>(params);
int batch = GetBatchFromParams<T>(params);
if (batch > 1) {
int64_t stride_a = GetStrideAFromParams<CT>(params);
int64_t stride_b = GetStrideBFromParams<CT>(params);
int64_t stride_c = GetStrideCFromParams<CT>(params);
int64_t stride_a = GetStrideAFromParams<T>(params);
int64_t stride_b = GetStrideBFromParams<T>(params);
int64_t stride_c = GetStrideCFromParams<T>(params);
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
mat_a, HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutSetAttribute(
@ -439,27 +251,10 @@ class HipblasltGemmOp : public Callable<ParamsT> {
mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c)));
}
HipBlasLtMatmulDescriptor matmul(COMPUTE_TYPE_32, DATA_TYPE_R_32);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
// specific to scaled gemm
const void* mat1_scale_ptr = GetAScalePointerFromParams<CT>(params);
const void* mat2_scale_ptr = GetBScalePointerFromParams<CT>(params);
const void* result_scale_ptr = GetDScalePointerFromParams<CT>(params);
if (mat1_scale_ptr && mat2_scale_ptr && result_scale_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat1_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat2_scale_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
const void* bias_ptr = GetBiasPointerFromParams<CT>(params);
auto bias_datatype = GetBiasTypeFromParams<CT>(params);
if (bias_ptr) {
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
matmul.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, bias_datatype);
}
}
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &opa, sizeof(int32_t)));
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &opb, sizeof(int32_t)));
size_t workspace_size = GetHipblasltWorkspaceSize();
@ -467,7 +262,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
size_t ret_workspace_size = 0;
auto status = hipblaslt_ext::matmulIsAlgoSupported(op_handle,
matmul.descriptor(),
matmul,
&alpha,
mat_a,
mat_b,
@ -494,7 +289,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
}
TORCH_HIPBLASLT_CHECK(hipblasLtMatmul(op_handle,
matmul.descriptor(),
matmul,
&alpha,
params->a,
mat_a,
@ -510,7 +305,7 @@ class HipblasltGemmOp : public Callable<ParamsT> {
workspace_size,
at::cuda::getCurrentCUDAStream()));
//TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
TORCH_HIPBLASLT_CHECK(hipblasLtMatmulDescDestroy(matmul));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_a));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_b));
TORCH_HIPBLASLT_CHECK(hipblasLtMatrixLayoutDestroy(mat_c));
@ -524,13 +319,11 @@ class HipblasltGemmOp : public Callable<ParamsT> {
hipblasLtMatmulAlgo_t algo_;
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
template <typename T, BlasOp ALayout, BlasOp BLayout, typename ParamsT>
auto GetHipBlasLtTypeStringAndOps() {
hipblasOperation_t transa_outer = MapLayoutToHipBlasLt(ALayout);
hipblasOperation_t transb_outer = MapLayoutToHipBlasLt(BLayout);
auto a_datatype = HipBlasDataTypeFor<AT>();
auto b_datatype = HipBlasDataTypeFor<BT>();
auto in_out_datatype = HipBlasDataTypeFor<CT>();
auto in_out_datatype = HipBlasDataTypeFor<T>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
hipblasLtHandle_t handle;
@ -539,8 +332,8 @@ auto GetHipBlasLtTypeStringAndOps() {
hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
transa_outer,
transb_outer,
a_datatype,
b_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
COMPUTE_TYPE_32,
@ -559,7 +352,7 @@ auto GetHipBlasLtTypeStringAndOps() {
for (int i = 0; i < returned_algo_count; i++) {
auto algo = heuristic_result[i].algo;
int algo_index = GETINDEXFROMALGO(algo);
auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
auto callable = std::make_unique<HipblasltGemmOp<T, ALayout, BLayout, ParamsT>>(algo);
std::string type_string = c10::str(
"Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
ret.emplace_back(type_string, std::move(callable));
@ -570,17 +363,12 @@ auto GetHipBlasLtTypeStringAndOps() {
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmParams<T>>();
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmParams<T>>();
}
template <typename T, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtGemmStridedBatchedTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<T, T, T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
auto GetHipBlasLtScaledGemmTypeStringAndOps() {
return GetHipBlasLtTypeStringAndOps<AT, BT, CT, ALayout, BLayout, ScaledGemmParams<CT>>();
return GetHipBlasLtTypeStringAndOps<T, ALayout, BLayout, GemmStridedBatchedParams<T>>();
}
#undef TORCH_HIPBLASLT_CHECK
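The HipBlasLtDescriptor/HipBlasLtMatmulDescriptor classes removed in this file are an RAII wrapper around a C-style handle whose destroy function returns a status code. A generic sketch of that pattern (illustration only; the Status/OpaqueDesc types and create/destroy helpers are stand-ins, not hipBLASLt API):

#include <cstdio>
#include <memory>

using Status = int;                // stand-in for hipblasStatus_t
struct OpaqueDesc { int dummy; };  // stand-in for hipblasLtMatmulDescOpaque_t

static Status createDesc(OpaqueDesc** out) { *out = new OpaqueDesc{0}; return 0; }
static Status destroyDesc(OpaqueDesc* d)   { delete d; return 0; }

template <typename T, Status (*destructor)(T*)>
struct CheckedDeleter {
  void operator()(T* x) const {
    // The real deleter routes the status through TORCH_CUDABLAS_CHECK; here we just report it.
    if (x != nullptr && destructor(x) != 0) {
      std::fprintf(stderr, "descriptor destructor failed\n");
    }
  }
};

template <typename T, Status (*destructor)(T*)>
class Descriptor {
 public:
  explicit Descriptor(T* raw) : descriptor_(raw) {}
  T* descriptor() const { return descriptor_.get(); }
 private:
  std::unique_ptr<T, CheckedDeleter<T, destructor>> descriptor_;
};

int main() {
  OpaqueDesc* raw = nullptr;
  createDesc(&raw);
  Descriptor<OpaqueDesc, destroyDesc> desc(raw);  // freed automatically at scope exit
  std::printf("descriptor at %p\n", (void*)desc.descriptor());
  return 0;
}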

View File

@ -19,10 +19,6 @@
#include <ATen/cuda/tunable/StreamTimer.h>
#include <ATen/cuda/tunable/TunableOp.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/StringUtil.h>
#ifdef USE_ROCM
@ -68,112 +64,62 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
};
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
public:
TuningStatus Call(const ScaledGemmParams<T>* params) override {
at::cuda::blas::scaled_gemm(
params->transa,
params->transb,
params->m,
params->n,
params->k,
params->a,
params->a_scale_ptr,
params->lda,
params->a_dtype,
params->b,
params->b_scale_ptr,
params->ldb,
params->b_dtype,
params->bias_ptr,
params->bias_dtype,
params->c,
params->c_scale_ptr,
params->ldc,
params->c_dtype,
params->amax_ptr,
params->use_fast_accum);
return OK;
}
};
template <typename T>
inline bool IsZero(T v) {
bool IsZero(T v) {
return v == 0.0f;
}
template <>
inline bool IsZero(BFloat16 v) {
bool IsZero(BFloat16 v) {
return v.x == 0;
}
template <>
inline bool IsZero(Half v) {
bool IsZero(Half v) {
return float(v) == 0.0f;
}
template <>
inline bool IsZero(c10::complex<double> v) {
bool IsZero(c10::complex<double> v) {
return v == 0.0;
}
template <>
inline bool IsZero(c10::complex<float> v) {
bool IsZero(c10::complex<float> v) {
return v == 0.0f;
}
template <typename T>
inline std::string TypeName(T v) {
std::string TypeName(T v) {
return "unknown";
}
template <>
inline std::string TypeName(float v) {
std::string TypeName(float v) {
return "float";
}
template <>
inline std::string TypeName(double v) {
std::string TypeName(double v) {
return "double";
}
template <>
inline std::string TypeName(BFloat16 v) {
std::string TypeName(BFloat16 v) {
return "BFloat16";
}
template <>
inline std::string TypeName(Half v) {
std::string TypeName(Half v) {
return "Half";
}
template <>
inline std::string TypeName(Float8_e4m3fn v) {
return "Float8_e4m3fn";
}
template <>
inline std::string TypeName(Float8_e5m2 v) {
return "Float8_e5m2";
}
template <>
inline std::string TypeName(Float8_e4m3fnuz v) {
return "Float8_e4m3fnuz";
}
template <>
inline std::string TypeName(Float8_e5m2fnuz v) {
return "Float8_e5m2fnuz";
}
template <>
inline std::string TypeName(c10::complex<double> v) {
std::string TypeName(c10::complex<double> v) {
return "c10::complex<double>";
}
template <>
inline std::string TypeName(c10::complex<float> v) {
std::string TypeName(c10::complex<float> v) {
return "c10::complex<float>";
}
@ -326,42 +272,6 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
};
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
public:
ScaledGemmTunableOp() {
this->RegisterOp(std::string("Default"), std::make_unique<DefaultScaledGemmOp<CT>>());
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
#if defined(USE_ROCM) && ROCM_VERSION >= 50700
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
std::string hipblaslt_version = c10::str(
XSTRINGIFY(HIPBLASLT_VERSION_MAJOR), ".",
XSTRINGIFY(HIPBLASLT_VERSION_MINOR), ".",
XSTRINGIFY(HIPBLASLT_VERSION_PATCH), "-",
XSTRINGIFY(HIPBLASLT_VERSION_TWEAK));
getTuningContext()->GetTuningResultsValidator().RegisterValidator(
"HIPBLASLT_VERSION",
[hipblaslt_version]() { return hipblaslt_version; },
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
}
#endif
}
std::string Signature() override {
return c10::str("ScaledGemmTunableOp",
"_", TypeName<AT>(AT{}),
"_", TypeName<BT>(BT{}),
"_", TypeName<CT>(CT{}),
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
#undef XSTRINGIFY
#undef STRINGIFY

View File

@ -81,11 +81,6 @@ static Tensor unsafeMakeTensorWrapper(
auto result = at::detail::make_tensor<TensorWrapper>(
key_set, tensor, level, life_handle, is_immutable);
TORCH_INTERNAL_ASSERT(result.key_set().has(DispatchKey::FuncTorchGradWrapper));
if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
result.unsafeGetTensorImpl()->set_wrapped_number(true);
}
return result;
}

View File

@ -299,12 +299,6 @@ public:
void StartTrace(const std::string& mode, bool waitUntilCompleted);
void StopTrace();
// Abstractions for GPU trace capturing
bool isCaptureEnabled() const;
bool isCapturing() const;
void startCapture(const std::string& name, MPSStream* stream = nullptr);
void stopCapture(MPSStream* stream = nullptr);
// convenience functions to indicate whether signpost tracing or
// logging are enabled for the SignpostTypes
bool isOperationProfilingEnabled() const {
@ -362,9 +356,6 @@ public:
// a short list that contains copy stats
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
mutable MTLCaptureManager *captureManager = nil;
unsigned captureCount = 0;
void initialize();
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,

View File

@ -765,41 +765,6 @@ void MPSProfiler::handleIntSignal(int signal) {
struct sigaction MPSProfiler::currentSigint {};
struct sigaction MPSProfiler::previousSigint {};
bool MPSProfiler::isCapturing() const {
return [captureManager isCapturing];
}
bool MPSProfiler::isCaptureEnabled() const {
if (captureManager == nil) {
captureManager = [MTLCaptureManager sharedCaptureManager];
}
static bool isEnabled = [this]() {
return [captureManager supportsDestination:MTLCaptureDestinationGPUTraceDocument];
}();
return isEnabled;
}
void MPSProfiler::startCapture(const std::string& name, MPSStream* stream) {
if (captureManager == nil) {
captureManager = [MTLCaptureManager sharedCaptureManager];
}
NSError* err = nil;
NSString* fname = [NSString stringWithFormat:@"%04d-%s.gputrace", captureCount++, name.c_str()];
MTLCaptureDescriptor* captureDescriptor = [MTLCaptureDescriptor new];
captureDescriptor.captureObject = stream ? (id)stream->commandQueue() : (id)MPSDevice::getInstance()->device();
captureDescriptor.destination = MTLCaptureDestinationGPUTraceDocument;
captureDescriptor.outputURL = [NSURL fileURLWithPath:fname];
auto rc = [captureManager startCaptureWithDescriptor:captureDescriptor error:&err];
TORCH_CHECK(rc, "Failed to start capture of ", [fname UTF8String], " error ", [[err description] UTF8String]);
}
void MPSProfiler::stopCapture(MPSStream* stream) {
if (stream) {
stream->synchronize(SyncType::COMMIT);
}
[captureManager stopCapture];
}
} // namespace Profiler
Profiler::MPSProfiler& getMPSProfiler() {

View File

@ -22,7 +22,7 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
_compilationDescriptor = [MPSGraphCompilationDescriptor new];
// disable commitAndContinue if Signpost tracing is enabled
if (getMPSProfiler().isSignpostTracingEnabled() || getMPSProfiler().isCaptureEnabled()) {
if (getMPSProfiler().isSignpostTracingEnabled()) {
_enableCommitAndContinue = false;
}
_executionDescriptor.enableCommitAndContinue = _enableCommitAndContinue;

View File

@ -317,12 +317,6 @@ Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_siz
// in this case, adaptive pooling is just computing mean over the dhw
// dimensions, which can be done more efficiently
Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true);
if (input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d) {
// assert ndim == 5, since ndim = 4 doesn't give channels_last
const auto n = input.sym_size(0);
const auto c = input.sym_size(1);
out.as_strided__symint({n, c, 1, 1, 1}, {c, 1, c, c, c});
}
return out;
} else {
return _adaptive_avg_pool3d_symint(input, output_size);

View File

@ -8,25 +8,15 @@
namespace at::native {
using adaptive_avg_pooling2d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling2d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling2d_fn, adaptive_avg_pool2d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling2d_backward_fn, adaptive_avg_pool2d_backward_kernel);
using adaptive_avg_pooling_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling_fn, adaptive_avg_pool2d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling_backward_fn, adaptive_avg_pool2d_backward_kernel);
using adaptive_max_pooling2d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling2d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling2d_fn, adaptive_max_pool2d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling2d_backward_fn, adaptive_max_pool2d_backward_kernel);
using adaptive_avg_pooling3d_fn = void(*)(Tensor& output, const Tensor& input, IntArrayRef output_size);
using adaptive_avg_pooling3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output);
DECLARE_DISPATCH(adaptive_avg_pooling3d_fn, adaptive_avg_pool3d_kernel);
DECLARE_DISPATCH(adaptive_avg_pooling3d_backward_fn, adaptive_avg_pool3d_backward_kernel);
using adaptive_max_pooling3d_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling3d_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling3d_fn, adaptive_max_pool3d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling3d_backward_fn, adaptive_max_pool3d_backward_kernel);
using adaptive_max_pooling_fn = void(*)(const Tensor& output, const Tensor& indices, const Tensor& input, IntArrayRef output_size);
using adaptive_max_pooling_backward_fn = void(*)(const Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);
DECLARE_DISPATCH(adaptive_max_pooling_fn, adaptive_max_pool2d_kernel);
DECLARE_DISPATCH(adaptive_max_pooling_backward_fn, adaptive_max_pool2d_backward_kernel);
static inline int64_t start_index(int64_t a, int64_t b, int64_t c) {
return (a / b) * c + ((a % b) * c) / b;
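The start_index helper shown above computes floor(a*c/b) without forming the a*c product: writing a = q*b + r gives floor(a*c/b) = q*c + floor(r*c/b), which is exactly (a/b)*c + ((a%b)*c)/b in integer arithmetic. A small sketch of the resulting adaptive-pooling bin boundaries (the end_index definition below is an assumption mirroring the companion helper in this header):

#include <cstdint>
#include <cstdio>

static int64_t start_index(int64_t a, int64_t b, int64_t c) {
  return (a / b) * c + ((a % b) * c) / b;  // == floor(a * c / b), overflow-safe
}
// Assumption: mirrors the companion end_index helper in this header.
static int64_t end_index(int64_t a, int64_t b, int64_t c) {
  return 1 + ((a + 1) * c - 1) / b;        // == ceil((a + 1) * c / b)
}

int main() {
  // Adaptively pool an input of width 10 down to 4 output bins.
  for (int64_t ow = 0; ow < 4; ++ow) {
    std::printf("bin %lld covers input [%lld, %lld)\n",
                (long long)ow,
                (long long)start_index(ow, 4, 10),
                (long long)end_index(ow, 4, 10));
  }
  return 0;
}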

View File

@ -10,19 +10,8 @@
#include <cstdlib>
#include <cstring>
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#include <sys/auxv.h>
#endif
namespace at::native {
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
static inline bool cpu_has_vxe()
{
return (getauxval(AT_HWCAP) & HWCAP_S390_VXE);
}
#endif
static CPUCapability compute_cpu_capability() {
auto envar = std::getenv("ATEN_CPU_CAPABILITY");
if (envar) {
@ -71,16 +60,10 @@ static CPUCapability compute_cpu_capability() {
#endif
}
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
// vxe is needed for fp32 vector instructions
if (cpu_has_vxe()) {
return CPUCapability::ZVECTOR;
}
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
return CPUCapability::VSX;
#elif HAVE_ZVECTOR_CPU_DEFINITION
return CPUCapability::ZVECTOR;
#else
return CPUCapability::DEFAULT;
#endif

View File

@ -2839,16 +2839,10 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar
}
if (is_reduce_over_1D_vector) {
Tensor self_;
if (opt_dtype.has_value()) {
self_ = self.to(*opt_dtype);
} else {
self_ = self;
}
if (ord != 0.0) {
keepdim ? at::abs_outf(self_, const_cast<Tensor&>(result)) : at::abs_outf(self_.squeeze(reduce_dim), const_cast<Tensor&>(result));
keepdim ? at::abs_outf(self, const_cast<Tensor&>(result)) : at::abs_outf(self.squeeze(reduce_dim), const_cast<Tensor&>(result));
} else {
keepdim ? at::ne_outf(self_, 0, const_cast<Tensor&>(result)) : at::ne_outf(self_.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
keepdim ? at::ne_outf(self, 0, const_cast<Tensor&>(result)) : at::ne_outf(self.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
}
return;
}
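The fast path above relies on the fact that when each reduced slice contains a single element x, the vector norm degenerates to an elementwise expression, which is why abs (for ord != 0) and ne (for ord == 0) can be used directly; in LaTeX:

\[
\|x\|_p = \left(|x|^p\right)^{1/p} = |x| \quad (p \neq 0), \qquad
\|x\|_0 = \mathbb{1}[x \neq 0].
\]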

View File

@ -26,19 +26,6 @@ using avg_pool2d_backward_fn = void(*)(const Tensor& output, const Tensor& input
DECLARE_DISPATCH(avg_pool2d_fn, avg_pool2d_kernel);
DECLARE_DISPATCH(avg_pool2d_backward_fn, avg_pool2d_backward_kernel);
// average pooling has the same signature for forward and backward
using avg_pool3d_fn = void(*)(const Tensor& output, const Tensor& input,
int64_t kW, int64_t kH, int64_t kD, int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD, bool count_include_pad,
c10::optional<int64_t> divisor_override);
using avg_pool3d_backward_fn = void(*)(const Tensor& output, const Tensor& input,
int kW, int kH, int kD, int dW, int dH, int dD,
int padW, int padH, int padD, bool count_include_pad,
c10::optional<int64_t> divisor_override);
DECLARE_DISPATCH(avg_pool3d_fn, avg_pool3d_kernel);
DECLARE_DISPATCH(avg_pool3d_backward_fn, avg_pool3d_backward_kernel);
using max_pool3d_fn = void(*)(Tensor& output, Tensor& indices, const Tensor& input,
int kW, int kH, int kD, int dW, int dH, int dD, int pW, int pH, int pD, int dilationW, int dilationH, int dilationD);
using max_pool3d_backward_fn = void(*)(Tensor& grad_input, const Tensor& grad_output, const Tensor& indices);

View File

@ -254,50 +254,13 @@ Tensor _to_copy(
// TODO: Use the dispatcher for this.
// Currently there are unenumerated extensibility issues preventing this.
if (self.layout() == kSparse) {
TORCH_CHECK(
memory_format == MemoryFormat::Preserve,
"to(options): COO only supports memory format Preserve, but got ", memory_format,
" instead.");
if (options.device().is_meta()) {
return zeros_like(self, options);
}
auto indices = self._indices();
const auto new_indices = at::native::to(
indices,
indices.scalar_type(),
c10::kStrided,
device,
pin_memory,
non_blocking,
true, // force copy since we are in _to_copy
memory_format);
const auto new_values = at::native::to(
self._values(),
dtype,
c10::kStrided,
device,
pin_memory,
non_blocking,
true, // force copy since we are in _to_copy
memory_format);
return at::_sparse_coo_tensor_unsafe(
new_indices,
new_values,
self.sizes(),
options, self.is_coalesced());
} else if (at::sparse_csr::is_sparse_compressed(self)) {
if (at::sparse_csr::is_sparse_compressed(self)) {
TORCH_CHECK(
memory_format == MemoryFormat::Preserve,
"to(options): ", at::sparse_csr::layoutToString(self.layout()),
" only supports memory format Preserve, but got ", memory_format,
" instead.");
if (options.device().is_meta()) {
return zeros_like(self, options);
}
auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(self);
const auto new_values = at::native::to(

View File

@ -421,19 +421,9 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st
// it. TODO: Actually this might not quite be correct if we use special
// pointers to track whether or not fake cuda tensors are pinned or not
const auto itemsize = result.dtype().itemsize();
c10::SymInt new_size_bytes = at::detail::computeStorageNbytes(
c10::SymInt size_bytes = at::detail::computeStorageNbytes(
size, stride, itemsize, std::move(storage_offset));
// TODO: When there are unbacked SymInts, we unconditionally skip the
// setter. This is technically wrong, but we cannot conveniently test
// the real condition in many cases, because a lot of people are using
// set_ just to swizzle metadata on a tensor, they didn't actually want
// to see if they need to resize the storage.
//
// The old behavior was to unconditionally set_nbytes, but I think not
// setting it is more safe.
if (new_size_bytes.has_hint() && storage.sym_nbytes().has_hint() && TORCH_GUARD_SIZE_OBLIVIOUS(new_size_bytes.sym_gt(storage.sym_nbytes()))) {
storage.set_nbytes(std::move(new_size_bytes));
}
storage.set_nbytes(std::move(size_bytes));
}
return result;
}
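The removed guard above only grows the storage when both byte counts carry hints and the new size is provably larger, as the TODO explains. A minimal sketch of that policy with plain integers standing in for SymInts (the FakeStorage type and maybe_grow helper are illustrative assumptions):

#include <cstdint>
#include <cstdio>
#include <optional>

struct FakeStorage {
  int64_t nbytes = 0;
  void set_nbytes(int64_t n) { nbytes = n; }
};

// std::nullopt models a size with no hint (an unbacked SymInt): skip the setter.
static void maybe_grow(FakeStorage& storage, std::optional<int64_t> new_size_bytes) {
  if (new_size_bytes.has_value() && *new_size_bytes > storage.nbytes) {
    storage.set_nbytes(*new_size_bytes);
  }
}

int main() {
  FakeStorage s;
  s.nbytes = 128;
  maybe_grow(s, 64);            // smaller than current: left untouched
  maybe_grow(s, std::nullopt);  // no hint: left untouched
  maybe_grow(s, 256);           // provably larger: grows
  std::printf("nbytes = %lld\n", (long long)s.nbytes);  // prints 256
  return 0;
}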
@ -4082,13 +4072,11 @@ void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList o
}
}
int64_t sparse_dim_default(const Tensor& self) {
TORCH_CHECK(self.layout() == kStrided, "sparse_dim expected sparse or strided tensor layout but got ", self.layout());
int64_t sparse_dim_strided(const at::Tensor& self) {
return 0;
}
int64_t dense_dim_default(const Tensor& self) {
TORCH_CHECK(self.layout() == kStrided, "dense_dim expected sparse or strided tensor layout but got ", self.layout());
int64_t dense_dim_strided(const at::Tensor& self) {
return self.dim();
}

View File

@ -15,7 +15,7 @@ namespace at::native {
namespace {
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool2d(
void cpu_adaptive_avg_pool(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
@ -69,7 +69,7 @@ void cpu_adaptive_avg_pool2d(
template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
cpu_adaptive_avg_pool_channels_last(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
@ -156,7 +156,7 @@ cpu_adaptive_avg_pool2d_channels_last(
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool2d_channels_last(
cpu_adaptive_avg_pool_channels_last(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
@ -255,7 +255,7 @@ cpu_adaptive_avg_pool2d_channels_last(
}
template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward(
void cpu_adaptive_avg_pool_backward(
Tensor& grad_input_,
const Tensor& grad_output_) {
auto grad_output = grad_output_.contiguous();
@ -305,7 +305,7 @@ void cpu_adaptive_avg_pool2d_backward(
}
template <typename scalar_t>
void cpu_adaptive_avg_pool2d_backward_channels_last(
void cpu_adaptive_avg_pool_backward_channels_last(
Tensor& grad_input_,
const Tensor& grad_output_) {
auto memory_format = at::MemoryFormat::ChannelsLast;
@ -373,13 +373,13 @@ void adaptive_avg_pool2d_kernel_impl(
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d", [&] {
using param_t = at::opmath_type<scalar_t>;
cpu_adaptive_avg_pool2d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
cpu_adaptive_avg_pool<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool2d_channels_last", [&]{
cpu_adaptive_avg_pool2d_channels_last<scalar_t>(output, input, output_size);
cpu_adaptive_avg_pool_channels_last<scalar_t>(output, input, output_size);
});
break;
}
@ -394,458 +394,13 @@ void adapative_avg_pool2d_backward_kernel_impl(
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward", [&] {
cpu_adaptive_avg_pool2d_backward<scalar_t>(grad_input, grad_output);
cpu_adaptive_avg_pool_backward<scalar_t>(grad_input, grad_output);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool2d_backward_channels_last", [&]{
cpu_adaptive_avg_pool2d_backward_channels_last<scalar_t>(grad_input, grad_output);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_avg_pool3d(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
auto input = input_.contiguous();
auto output = output_.contiguous();
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t ndim = input.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
int64_t input_depth = input.size(-3);
int64_t input_height = input.size(-2);
int64_t input_width = input.size(-1);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
// parallel on dim of N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t kd = id1 - id0;
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;
for (const auto ow : c10::irange(output_width)) {
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;
// compute local average
accscalar_t sum = 0;
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
sum += accscalar_t(input_ptr[id * input_height * input_width + ih * input_width + iw]);
}
}
}
output_ptr[od * output_height * output_width + oh * output_width + ow] = scalar_t(sum / kd / kh / kw);
}
}
}
}
});
if (!output_.is_contiguous()) {
output_.copy_(output);
}
}
template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
using Vec = vec::Vectorized<scalar_t>;
// parallel on dim N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
for (const auto i : c10::irange(begin, end)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t kd = id1 - id0;
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;
scalar_t* out = output_data + i * channels;
int64_t size = channels;
// Note: For ordinary usage scenarios, each out lane should
// fit in L1 cache; otherwise consider block dim C.
// Pass I: zero the out lane
int64_t d1 = 0;
for (; d1 < size - (size % Vec::size()); d1 += Vec::size()) {
Vec out_vec = Vec(scalar_t(0));
out_vec.store(out + d1);
}
for (; d1 < size; d1++) {
out[d1] = scalar_t(0);
}
// Pass II: compute local sum
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < size - (size % Vec::size()); d2 += Vec::size()) {
Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
out_vec.store(out + d2);
}
for (; d2 < size; d2++) {
out[d2] += in[d2];
}
}
}
}
// Pass III: compute local average
int64_t d3 = 0;
for (; d3 < size - (size % Vec::size()); d3 += Vec::size()) {
Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(kd * kh * kw));
out_vec.store(out + d3);
}
for (; d3 < size; d3++) {
out[d3] = out[d3] / kd / kh / kw;
}
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
}
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_avg_pool3d_channels_last(
Tensor& output_,
const Tensor& input_,
IntArrayRef output_size) {
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
using bVec = vec::Vectorized<scalar_t>;
using fVec = vec::Vectorized<float>;
// parallel on dim N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t oh = 0;
int64_t ow = 0;
int64_t od = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
// temp buffer for sum, use float as accumulation type
// can't reuse output buffer to store sum since it is BFloat16/Half
auto sum_arr = std::make_unique<float []>(channels);
float* sum = sum_arr.get();
for (const auto i : c10::irange(begin, end)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t kd = id1 - id0;
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;
scalar_t* out = output_data + i * channels;
int64_t size = channels;
// Pass I: zero the out lane
int64_t d1 = 0;
for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
fVec sum_fvec = fVec(float(0));
sum_fvec.store(sum + d1);
}
for (; d1 < size; d1++) {
sum[d1] = float(0);
}
// Pass II: compute local sum
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels +
ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
bVec data_bvec = bVec::loadu(in + d2);
fVec data_fvec0, data_fvec1;
std::tie(data_fvec0, data_fvec1) = convert_to_float<scalar_t>(data_bvec);
fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
sum_fvec0.store(sum + d2);
sum_fvec1.store(sum + d2 + fVec::size());
}
for (; d2 < size; d2++) {
sum[d2] += float(in[d2]);
}
}
}
}
// Pass III: compute local average
int64_t d3 = 0;
for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(kd * kh * kw));
fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(kd * kh * kw));
bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
out_bvec.store(out + d3);
}
for (; d3 < size; d3++) {
out[d3] = scalar_t(sum[d3] / kd / kh / kw);
}
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
}
template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward(
Tensor& grad_input_,
const Tensor& grad_output_) {
auto grad_output = grad_output_.contiguous();
auto grad_input = grad_input_.contiguous();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
int64_t ndim = grad_output.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
int64_t input_depth = grad_input.size(-3);
int64_t input_height = grad_input.size(-2);
int64_t input_width = grad_input.size(-1);
int64_t output_depth = grad_output.size(-3);
int64_t output_height = grad_output.size(-2);
int64_t output_width = grad_output.size(-1);
// parallel on dim of N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t kd = id1 - id0;
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;
for (const auto ow : c10::irange(output_width)) {
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;
scalar_t grad_delta = grad_output_ptr[od * output_width * output_height + oh * output_width + ow] / kd / kh / kw;
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
grad_input_ptr[id * input_height * input_width + ih * input_width + iw] += grad_delta;
}
}
}
}
}
}
}
});
if (!grad_input_.is_contiguous()) {
grad_input_.copy_(grad_input);
}
}
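The backward kernel above distributes each output gradient uniformly over its pooling bin; in LaTeX, with bin B(o) of size k_d(o) k_h(o) k_w(o) for output location o,

\[
\frac{\partial L}{\partial x_i} \;=\; \sum_{o\,:\, i \in B(o)} \frac{1}{k_d(o)\,k_h(o)\,k_w(o)} \cdot \frac{\partial L}{\partial y_o}.
\]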
template <typename scalar_t>
void cpu_adaptive_avg_pool3d_backward_channels_last(
Tensor& grad_input_,
const Tensor& grad_output_) {
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto grad_input = grad_input_.contiguous(memory_format);
auto grad_output = grad_output_.contiguous(memory_format);
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
int64_t nbatch = grad_input.size(0);
int64_t channels = grad_input.size(1);
int64_t input_depth = grad_input.size(2);
int64_t input_height = grad_input.size(3);
int64_t input_width = grad_input.size(4);
int64_t output_depth = grad_output.size(2);
int64_t output_height = grad_output.size(3);
int64_t output_width = grad_output.size(4);
using Vec = vec::Vectorized<scalar_t>;
// parallel on dim N
at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
for (const auto n : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t kd = id1 - id0;
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t kh = ih1 - ih0;
for (const auto ow : c10::irange(output_width)) {
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
int64_t kw = iw1 - iw0;
scalar_t* gout = grad_output_ptr + od * output_depth * channels + oh * output_width * channels + ow * channels;
int64_t size = channels;
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* gin = grad_input_ptr + id * input_width * input_height * channels + ih * input_width * channels + iw * channels;
int64_t d = 0;
for (; d < size - (size % Vec::size()); d += Vec::size()) {
Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(kd * kh * kw));
gin_vec.store(gin + d);
}
for (; d < size; d++) {
gin[d] += gout[d] / kd / kh / kw;
}
}
}
}
}
}
}
}
});
if (!grad_input_.is_contiguous(memory_format)) {
grad_input_.copy_(grad_input);
}
}
void adaptive_avg_pool3d_kernel_impl(
Tensor& output,
const Tensor& input,
IntArrayRef output_size) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d", [&] {
using param_t = at::opmath_type<scalar_t>;
cpu_adaptive_avg_pool3d<scalar_t, /*accscalar_t*/param_t>(output, input, output_size);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_avg_pool3d_channels_last", [&]{
cpu_adaptive_avg_pool3d_channels_last<scalar_t>(output, input, output_size);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
void adapative_avg_pool3d_backward_kernel_impl(
Tensor& grad_input,
const Tensor& grad_output) {
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward", [&] {
cpu_adaptive_avg_pool3d_backward<scalar_t>(grad_input, grad_output);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_avg_pool3d_backward_channels_last", [&]{
cpu_adaptive_avg_pool3d_backward_channels_last<scalar_t>(grad_input, grad_output);
cpu_adaptive_avg_pool_backward_channels_last<scalar_t>(grad_input, grad_output);
});
break;
}
@ -858,7 +413,5 @@ void adapative_avg_pool3d_backward_kernel_impl(
REGISTER_DISPATCH(adaptive_avg_pool2d_kernel, &adaptive_avg_pool2d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool2d_backward_kernel, &adapative_avg_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_kernel, &adaptive_avg_pool3d_kernel_impl);
REGISTER_DISPATCH(adaptive_avg_pool3d_backward_kernel, &adapative_avg_pool3d_backward_kernel_impl);
} // at::native

View File

@ -15,7 +15,7 @@ namespace at::native {
namespace {
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_max_pool2d(
void cpu_adaptive_max_pool(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
@ -83,13 +83,13 @@ void cpu_adaptive_max_pool2d(
template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_max_pool2d_channels_last(
cpu_adaptive_max_pool_channels_last(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
IntArrayRef output_size) {
TORCH_CHECK(input_.ndimension() == 4,
"2d adaptive max pooling with channels last format supports tensors with 4 dims");
"adaptive max pooling with channels last format supports tensors with 4 dims");
auto memory_format = at::MemoryFormat::ChannelsLast;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
@ -200,13 +200,13 @@ cpu_adaptive_max_pool2d_channels_last(
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_max_pool2d_channels_last(
cpu_adaptive_max_pool_channels_last(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
IntArrayRef output_size) {
TORCH_CHECK(input_.ndimension() == 4,
"2d adaptive max pooling with channels last format supports tensors with 4 dims");
"adaptive max pooling with channels last format supports tensors with 4 dims");
auto memory_format = at::MemoryFormat::ChannelsLast;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
@ -340,7 +340,7 @@ cpu_adaptive_max_pool2d_channels_last(
}
template <typename scalar_t>
void cpu_adaptive_max_pool2d_backward(
void cpu_adaptive_max_pool_backward(
const Tensor& grad_input_,
const Tensor& grad_output_,
const Tensor& indices_) {
@ -386,12 +386,12 @@ void cpu_adaptive_max_pool2d_backward(
}
template <typename scalar_t>
void cpu_adaptive_max_pool2d_backward_channels_last(
void cpu_adaptive_max_pool_backward_channels_last(
const Tensor& grad_input_,
const Tensor& grad_output_,
const Tensor& indices_) {
TORCH_CHECK(grad_output_.ndimension() == 4,
"2d adaptive max pooling backward with channels last format supports tensors with 4 dims.");
"adaptive max pooling backward with channels last format supports tensors with 4 dims.");
auto memory_format = at::MemoryFormat::ChannelsLast;
auto grad_input = grad_input_.contiguous(memory_format);
auto grad_output = grad_output_.contiguous(memory_format);
@ -443,13 +443,13 @@ void adaptive_max_pool2d_kernel_impl(
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_max_pool2d", [&] {
using param_t = at::opmath_type<scalar_t>;
cpu_adaptive_max_pool2d<scalar_t, /*accscalar_t*/param_t>(output, indices, input, output_size);
cpu_adaptive_max_pool<scalar_t, /*accscalar_t*/param_t>(output, indices, input, output_size);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_max_pool2d_channels_last", [&]{
cpu_adaptive_max_pool2d_channels_last<scalar_t>(output, indices, input, output_size);
cpu_adaptive_max_pool_channels_last<scalar_t>(output, indices, input, output_size);
});
break;
}
@ -466,512 +466,13 @@ void adaptive_max_pool2d_backward_kernel_impl(
switch (grad_input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_max_pool2d_backward", [&] {
cpu_adaptive_max_pool2d_backward<scalar_t>(grad_input, grad_output, indices);
cpu_adaptive_max_pool_backward<scalar_t>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_max_pool2d_backward_channels_last", [&]{
cpu_adaptive_max_pool2d_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
template <typename scalar_t, typename accscalar_t>
void cpu_adaptive_max_pool3d(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
IntArrayRef output_size) {
auto input = input_.contiguous();
auto output = output_.contiguous();
auto indices = indices_.contiguous();
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
int64_t ndim = input.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
int64_t input_depth = input.size(-3);
int64_t input_height = input.size(-2);
int64_t input_width = input.size(-1);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
// parallel on dim of N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
scalar_t* output_ptr = output_data + c * output_depth * output_height * output_width;
int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;
for (const auto od : c10::irange(output_depth)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
for (const auto oh : c10::irange(output_height)) {
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
for (const auto ow : c10::irange(output_width)) {
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
// compute local max
int64_t maxindex = id0 * input_height * input_width + ih0 * input_width + iw0;
accscalar_t maxval = -std::numeric_limits<accscalar_t>::infinity();
for (int64_t id = id0; id < id1; id ++) {
for (int64_t ih = ih0; ih < ih1; ih ++) {
for (int64_t iw = iw0; iw < iw1; iw ++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
scalar_t val = input_ptr[index];
if ((val > maxval) || std::isnan(val)) {
maxval = val;
maxindex = index;
}
}
}
}
// set output to local max and store location of max
output_ptr[od * output_height * output_width + oh * output_width + ow] = maxval;
indices_ptr[od * output_height * output_width + oh * output_width + ow] = scalar_t(maxindex);
}
}
}
}
});
if (!output_.is_contiguous()) {
output_.copy_(output);
}
if (!indices_.is_contiguous()) {
indices_.copy_(indices);
}
}
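The window bounds in the kernel above come from start_index/end_index. As a standalone sketch (not copied from this file), the rule is start = floor(o * in / out) and end = ceil((o + 1) * in / out), so neighbouring bins may overlap when the sizes do not divide evenly:
// Standalone sketch of the adaptive-pooling bin boundaries (illustrative only).
#include <cstdint>
#include <cstdio>
static inline int64_t start_index(int64_t o, int64_t out_size, int64_t in_size) {
  return (o * in_size) / out_size;                      // floor(o * in / out)
}
static inline int64_t end_index(int64_t o, int64_t out_size, int64_t in_size) {
  return ((o + 1) * in_size + out_size - 1) / out_size; // ceil((o + 1) * in / out)
}
int main() {
  // Pooling a 10-element axis down to 4 outputs gives bins [0,3) [2,5) [5,8) [7,10).
  for (int64_t o = 0; o < 4; ++o) {
    std::printf("out %lld -> [%lld, %lld)\n", (long long)o,
                (long long)start_index(o, 4, 10), (long long)end_index(o, 4, 10));
  }
  return 0;
}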
template <typename scalar_t>
typename std::enable_if_t<std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_max_pool3d_channels_last(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
IntArrayRef output_size) {
TORCH_CHECK(input_.ndimension() == 5,
"3d adaptive max pooling with channels last format supports tensors with 5 dims");
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto indices = indices_.contiguous(memory_format);
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
using Vec = vec::Vectorized<scalar_t>;
using integer_t = vec::int_same_size_t<scalar_t>;
using iVec = vec::Vectorized<integer_t>;
// for the convenience of vectorization, use an integer of the same size as scalar_t,
// e.g. int32_t for float, int64_t for double
// need to make sure doesn't overflow
TORCH_CHECK(input_depth * input_height * input_width <= std::numeric_limits<integer_t>::max());
// parallel on dim of N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
int64_t size = channels;
int64_t len = size - (size % Vec::size());
// temp buffer holding index with integer_t
auto index_buffer = std::make_unique<integer_t []>(len);
for (const auto i : c10::irange(begin, end)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
scalar_t* out = output_data + i * channels;
int64_t* ind = indices_data + i * channels;
// Pass I: init out lane
iVec index0_vec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
Vec out_vec = Vec(-std::numeric_limits<scalar_t>::infinity());
int64_t d1 = 0;
for (; d1 < len; d1 += Vec::size()) {
index0_vec.store(index_buffer.get() + d1);
out_vec.store(out + d1);
}
for (; d1 < size; d1++) {
ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
out[d1] = -std::numeric_limits<scalar_t>::infinity();
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id ++) {
for (int64_t ih = ih0; ih < ih1; ih ++) {
for (int64_t iw = iw0; iw < iw1; iw ++) {
scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < len; d2 += Vec::size()) {
iVec index_vec = iVec(id * input_height * input_width + ih * input_width + iw);
Vec val_vec = Vec::loadu(in + d2);
iVec maxindex_vec = iVec::loadu(index_buffer.get() + d2);
Vec maxval_vec = Vec::loadu(out + d2);
// true = all ones, false = all zeros
Vec mask = (val_vec > maxval_vec) | val_vec.isnan();
iVec imask = vec::cast<integer_t>(mask);
Vec out_vec = Vec::blendv(maxval_vec, val_vec, mask);
iVec ind_vec = iVec::blendv(maxindex_vec, index_vec, imask);
out_vec.store(out + d2);
ind_vec.store(index_buffer.get() + d2);
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
scalar_t val = in[d2];
int64_t maxindex = ind[d2];
scalar_t maxval = out[d2];
bool mask = (val > maxval) || std::isnan(val);
out[d2] = mask ? val : maxval;
ind[d2] = mask ? index : maxindex;
}
}
}
}
// convert index data type
vec::convert<integer_t, int64_t>(index_buffer.get(), ind, len);
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
if (!indices_.is_contiguous(memory_format)) {
indices_.copy_(indices);
}
}
template <typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
cpu_adaptive_max_pool3d_channels_last(
const Tensor& output_,
const Tensor& indices_,
const Tensor& input_,
IntArrayRef output_size) {
TORCH_CHECK(input_.ndimension() == 5,
"3d adaptive max pooling with channels last format supports tensors with 5 dims");
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto indices = indices_.contiguous(memory_format);
auto input_data = input.data_ptr<BFloat16>();
auto output_data = output.data_ptr<BFloat16>();
auto indices_data = indices.data_ptr<int64_t>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output_size[0];
int64_t output_height = output_size[1];
int64_t output_width = output_size[2];
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
using iVec = vec::Vectorized<int32_t>;
// need to make sure doesn't overflow
TORCH_CHECK(input_depth * input_height * input_width <= std::numeric_limits<int32_t>::max());
// parallel on dim of N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
int64_t size = channels;
int64_t len = size - (size % bVec::size());
// temp buffer holding index with int32_t
auto index_buffer = std::make_unique<int32_t []>(len);
// temp buffer holding max value with float
auto max_arr = std::make_unique<float []>(size);
float* max = max_arr.get();
for (const auto i : c10::irange(begin, end)) {
int64_t id0 = start_index(od, output_depth, input_depth);
int64_t id1 = end_index(od, output_depth, input_depth);
int64_t ih0 = start_index(oh, output_height, input_height);
int64_t ih1 = end_index(oh, output_height, input_height);
int64_t iw0 = start_index(ow, output_width, input_width);
int64_t iw1 = end_index(ow, output_width, input_width);
BFloat16* out = output_data + i * channels;
int64_t* ind = indices_data + i * channels;
// Pass I: init out lane
iVec index0_ivec = iVec(id0 * input_height * input_width + ih0 * input_width + iw0);
fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
int64_t d1 = 0;
for (; d1 < len; d1 += fVec::size()) {
index0_ivec.store(index_buffer.get() + d1);
max_fvec.store(max + d1);
}
for (; d1 < size; d1++) {
ind[d1] = id0 * input_height * input_width + ih0 * input_width + iw0;
max[d1] = -std::numeric_limits<float>::infinity();
}
// Pass II: compute local max
for (int64_t id = id0; id < id1; id ++) {
for (int64_t ih = ih0; ih < ih1; ih ++) {
for (int64_t iw = iw0; iw < iw1; iw ++) {
BFloat16* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < len; d2 += bVec::size()) {
iVec index_ivec = iVec(id * input_height * input_width + ih * input_width + iw);
bVec val_bvec = bVec::loadu(in + d2);
fVec val_fvec0, val_fvec1;
std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
iVec maxindex_ivec0 = iVec::loadu(index_buffer.get() + d2);
iVec maxindex_ivec1 = iVec::loadu(index_buffer.get() + d2 + iVec::size());
fVec maxval_fvec0 = fVec::loadu(max + d2);
fVec maxval_fvec1 = fVec::loadu(max + d2 + fVec::size());
// true = all ones, false = all zeros
fVec mask0 = (val_fvec0 > maxval_fvec0) | val_fvec0.isnan();
fVec mask1 = (val_fvec1 > maxval_fvec1) | val_fvec1.isnan();
iVec imask0 = vec::cast<int32_t>(mask0);
iVec imask1 = vec::cast<int32_t>(mask1);
fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
iVec ind_ivec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
iVec ind_ivec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);
max_fvec0.store(max + d2);
max_fvec1.store(max + d2 + fVec::size());
ind_ivec0.store(index_buffer.get() + d2);
ind_ivec1.store(index_buffer.get() + d2 + iVec::size());
}
for (; d2 < size; d2++) {
int64_t index = id * input_height * input_width + ih * input_width + iw;
float val = float(in[d2]);
int64_t maxindex = ind[d2];
float maxval = max[d2];
bool mask = (val > maxval) || std::isnan(val);
max[d2] = mask ? val : maxval;
ind[d2] = mask ? index : maxindex;
}
}
}
}
// Pass III: convert max values from float to bfloat16
int64_t d3 = 0;
for (; d3 < len; d3 += bVec::size()) {
fVec max_fvec0 = fVec::loadu(max + d3);
fVec max_fvec1 = fVec::loadu(max + d3 + fVec::size());
bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
max_bvec.store(out + d3);
}
for (; d3 < size; d3++) {
out[d3] = BFloat16(max[d3]);
}
// convert index data type
vec::convert<int32_t, int64_t>(index_buffer.get(), ind, len);
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
if (!indices_.is_contiguous(memory_format)) {
indices_.copy_(indices);
}
}
template <typename scalar_t>
void cpu_adaptive_max_pool3d_backward(
const Tensor& grad_input_,
const Tensor& grad_output_,
const Tensor& indices_) {
auto grad_output = grad_output_.contiguous();
auto indices = indices_.contiguous();
auto grad_input = grad_input_.contiguous();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
int64_t ndim = grad_output.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 3 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
int64_t input_depth = grad_input.size(-3);
int64_t input_height = grad_input.size(-2);
int64_t input_width = grad_input.size(-1);
int64_t output_depth = grad_output.size(-3);
int64_t output_height = grad_output.size(-2);
int64_t output_width = grad_output.size(-1);
// parallel on dim of N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
int64_t* indices_ptr = indices_data + c * output_depth * output_height * output_width;
for (const auto od : c10::irange(output_depth)) {
for (const auto oh : c10::irange(output_height)) {
for (const auto ow : c10::irange(output_width)) {
// retrieve position of max
int64_t index = od * output_height * output_width + oh * output_width + ow;
int64_t maxindex = indices_ptr[index];
// update gradient
grad_input_ptr[maxindex] += grad_output_ptr[index];
}
}
}
}
});
if (!grad_input_.is_contiguous()) {
grad_input_.copy_(grad_input);
}
}
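The backward kernel above routes each output gradient to the single input element recorded in indices during the forward pass. A minimal 1-D sketch of that scatter (hypothetical helper, not part of this file):
// 1-D sketch of the index-based gradient scatter performed above (illustrative only).
#include <cstdint>
#include <vector>
void max_pool_backward_1d(std::vector<float>& grad_in,
                          const std::vector<float>& grad_out,
                          const std::vector<int64_t>& indices) {
  // indices[o] is the flattened input position that produced output o in the forward pass.
  for (size_t o = 0; o < grad_out.size(); ++o) {
    grad_in[indices[o]] += grad_out[o];
  }
}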
template <typename scalar_t>
void cpu_adaptive_max_pool3d_backward_channels_last(
const Tensor& grad_input_,
const Tensor& grad_output_,
const Tensor& indices_) {
TORCH_CHECK(grad_output_.ndimension() == 5,
"3d adaptive max pooling backward with channels last format supports tensors with 5 dims.");
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto grad_input = grad_input_.contiguous(memory_format);
auto grad_output = grad_output_.contiguous(memory_format);
auto indices = indices_.contiguous(memory_format);
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
auto indices_data = indices.data_ptr<int64_t>();
int64_t nbatch = grad_input.size(0);
int64_t channels = grad_input.size(1);
int64_t input_depth = grad_input.size(2);
int64_t input_height = grad_input.size(3);
int64_t input_width = grad_input.size(4);
int64_t output_depth = grad_output.size(2);
int64_t output_height = grad_output.size(3);
int64_t output_width = grad_output.size(4);
// parallel on dim N
at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
for (const auto n : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;
int64_t* indices_ptr = indices_data + n * output_depth * output_height * output_width * channels;
for (const auto od : c10::irange(output_depth)) {
for (const auto oh : c10::irange(output_height)) {
for (const auto ow : c10::irange(output_width)) {
scalar_t* gout = grad_output_ptr + od * output_height * output_width * channels + oh * output_width * channels + ow * channels;
int64_t* ind = indices_ptr + od * output_height * output_width * channels + oh * output_width * channels + ow * channels;
// TODO: gcc vectorization
for (const auto c : c10::irange(channels)) {
int64_t maxindex = ind[c];
grad_input_ptr[maxindex * channels + c] += gout[c];
}
}
}
}
}
});
if (!grad_input_.is_contiguous(memory_format)) {
grad_input_.copy_(grad_input);
}
}
void adaptive_max_pool3d_kernel_impl(
const Tensor& output,
const Tensor& indices,
const Tensor& input,
IntArrayRef output_size) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_max_pool3d", [&] {
using param_t = at::opmath_type<scalar_t>;
cpu_adaptive_max_pool3d<scalar_t, /*accscalar_t*/param_t>(output, indices, input, output_size);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, input.scalar_type(), "adaptive_max_pool3d_channels_last", [&]{
cpu_adaptive_max_pool3d_channels_last<scalar_t>(output, indices, input, output_size);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
void adaptive_max_pool3d_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
const Tensor& indices) {
// can't use grad_output memory format to switch here since grad_output might be NC11
switch (grad_input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_max_pool3d_backward", [&] {
cpu_adaptive_max_pool3d_backward<scalar_t>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, grad_output.scalar_type(), "adaptive_max_pool3d_backward_channels_last", [&]{
cpu_adaptive_max_pool3d_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
cpu_adaptive_max_pool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
});
break;
}
@ -984,7 +485,5 @@ void adaptive_max_pool3d_backward_kernel_impl(
REGISTER_DISPATCH(adaptive_max_pool2d_kernel, &adaptive_max_pool2d_kernel_impl);
REGISTER_DISPATCH(adaptive_max_pool2d_backward_kernel, &adaptive_max_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(adaptive_max_pool3d_kernel, &adaptive_max_pool3d_kernel_impl);
REGISTER_DISPATCH(adaptive_max_pool3d_backward_kernel, &adaptive_max_pool3d_backward_kernel_impl);
} // at::native
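As a usage note, a libtorch call along the following lines would exercise the 3d kernels registered above on a CPU build (a sketch; shapes are arbitrary):
// Hypothetical libtorch usage (CPU build assumed); output and indices are both 1x4x2x2x2.
#include <torch/torch.h>
#include <iostream>
int main() {
  auto x = torch::randn({1, 4, 9, 9, 9});
  auto [out, idx] = torch::adaptive_max_pool3d(x, {2, 2, 2});  // returns (values, indices)
  std::cout << out.sizes() << " " << idx.sizes() << std::endl;
  return 0;
}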

View File

@ -14,7 +14,7 @@ namespace at::native {
namespace {
template <typename scalar_t>
void cpu_avg_pool2d(
void cpu_avg_pool(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH,
@ -101,7 +101,7 @@ void cpu_avg_pool2d(
template <typename scalar_t,
typename std::enable_if<!is_reduced_floating_point<scalar_t>::value, int>::type = 0>
void cpu_avg_pool2d_channels_last(
void cpu_avg_pool_channels_last(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH,
@ -110,7 +110,7 @@ void cpu_avg_pool2d_channels_last(
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(input_.ndimension() == 4,
"2d average pooling with channels last format supports tensors with 4 dims");
"average pooling with channels last format supports tensors with 4 dims");
auto memory_format = at::MemoryFormat::ChannelsLast;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
@ -215,7 +215,7 @@ void cpu_avg_pool2d_channels_last(
template <typename scalar_t,
typename std::enable_if<is_reduced_floating_point<scalar_t>::value, int>::type = 0>
void cpu_avg_pool2d_channels_last(
void cpu_avg_pool_channels_last(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH,
@ -224,7 +224,7 @@ void cpu_avg_pool2d_channels_last(
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(input_.ndimension() == 4,
"2d average pooling with channels last format supports tensors with 4 dims");
"average pooling with channels last format supports tensors with 4 dims");
auto memory_format = at::MemoryFormat::ChannelsLast;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
@ -347,7 +347,7 @@ void cpu_avg_pool2d_channels_last(
}
template <typename scalar_t>
void cpu_avg_pool2d_backward(
void cpu_avg_pool_backward(
const Tensor& grad_input_,
const Tensor& grad_output_,
int kW, int kH,
@ -415,7 +415,7 @@ void cpu_avg_pool2d_backward(
}
template <typename scalar_t>
void cpu_avg_pool2d_backward_channels_last(
void cpu_avg_pool_backward_channels_last(
const Tensor& grad_input_,
const Tensor& grad_output_,
int kW, int kH,
@ -463,7 +463,7 @@ void cpu_avg_pool2d_backward_channels_last(
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (ih1 - ih0) * (iw1 - iw0);
divide_factor = (ih1 - ih0) * (iw1 - iw0);
}
}
@ -505,13 +505,13 @@ void avg_pool2d_kernel_impl(
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, input.scalar_type(), "avg_pool2d", [&] {
cpu_avg_pool2d<scalar_t>(output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
cpu_avg_pool<scalar_t>(output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, input.scalar_type(), "avg_pool2d_channels_last", [&] {
cpu_avg_pool2d_channels_last<scalar_t>(output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
cpu_avg_pool_channels_last<scalar_t>(output, input, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
});
break;
}
@ -531,13 +531,13 @@ void avg_pool2d_backward_kernel_impl(
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, grad_output.scalar_type(), "avg_pool2d_backward", [&] {
cpu_avg_pool2d_backward<scalar_t>(grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
cpu_avg_pool_backward<scalar_t>(grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, grad_output.scalar_type(), "avg_pool2d_backward_channels_last", [&] {
cpu_avg_pool2d_backward_channels_last<scalar_t>(grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
cpu_avg_pool_backward_channels_last<scalar_t>(grad_input, grad_output, kW, kH, dW, dH, padW, padH, count_include_pad, divisor_override);
});
break;
}
@ -546,595 +546,9 @@ void avg_pool2d_backward_kernel_impl(
}
}
template <typename scalar_t>
void cpu_avg_pool3d(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH, int64_t kD,
int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
using acc_t = at::opmath_type<scalar_t>;
auto input = input_.contiguous();
auto output = output_.contiguous();
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t numel = output.numel();
int64_t ndim = input.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1);
int64_t input_depth = input.size(-3);
int64_t input_height = input.size(-2);
int64_t input_width = input.size(-1);
int64_t output_depth = output.size(-3);
int64_t output_height = output.size(-2);
int64_t output_width = output.size(-1);
// parallel on dim N, C, D, H, W
at::parallel_for(0, numel, 0, [&](int64_t begin, int64_t end) {
int64_t c = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, c, channels, od, output_depth, oh, output_height, ow, output_width);
for (const auto i : c10::irange(begin, end)) {
output_data[i] = static_cast<scalar_t>(0);
// local pointers
scalar_t* input_ptr = input_data + c * input_depth * input_height * input_width;
// compute the mean of the input image...
int64_t id0 = od * dD - padD;
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t id1 = std::min(id0 + kD, input_depth + padD);
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
id0 = std::max(id0, (int64_t) 0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
id1 = std::min(id1, input_depth);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
if (id0 >= id1 || ih0 >= ih1 || iw0 >= iw1) {
// move on to next output index
data_index_step(c, channels, od, output_depth, oh, output_height, ow, output_width);
continue;
}
acc_t sum = 0;
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
}
}
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
sum += input_ptr[id * input_height * input_width + ih * input_width + iw];
}
}
}
output_data[i] += scalar_t(sum / divide_factor);
// move on to next output index
data_index_step(c, channels, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous()) {
output_.copy_(output);
}
}
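The divisor handling in cpu_avg_pool3d above (divisor_override first, then count_include_pad choosing between the padded footprint and the clamped in-bounds extent) is easier to see in one dimension. A self-contained 1-D sketch, assuming the same semantics:
// Self-contained 1-D average pooling sketch (assumed to mirror the divisor semantics above).
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>
std::vector<float> avg_pool1d(const std::vector<float>& in, int64_t k, int64_t stride,
                              int64_t pad, bool count_include_pad,
                              std::optional<int64_t> divisor_override) {
  int64_t in_w = (int64_t)in.size();
  int64_t out_w = (in_w + 2 * pad - k) / stride + 1;
  std::vector<float> out(out_w, 0.f);
  for (int64_t ow = 0; ow < out_w; ++ow) {
    int64_t iw0 = ow * stride - pad;
    int64_t iw1 = std::min(iw0 + k, in_w + pad);
    int64_t pool_size = iw1 - iw0;          // footprint including padding
    iw0 = std::max(iw0, (int64_t)0);
    iw1 = std::min(iw1, in_w);
    if (iw0 >= iw1) continue;               // window entirely in padding: leave 0, like the kernel
    float sum = 0.f;
    for (int64_t iw = iw0; iw < iw1; ++iw) sum += in[iw];
    int64_t div = divisor_override ? *divisor_override
                                   : (count_include_pad ? pool_size : (iw1 - iw0));
    out[ow] = sum / (float)div;
  }
  return out;
}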
template <typename scalar_t,
typename std::enable_if<!is_reduced_floating_point<scalar_t>::value, int>::type = 0>
void cpu_avg_pool3d_channels_last(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH, int64_t kD,
int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(input_.ndimension() == 5,
"3d average pooling with channels last format supports tensors with 5 dims");
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto input_data = input.data_ptr<scalar_t>();
auto output_data = output.data_ptr<scalar_t>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output.size(2);
int64_t output_height = output.size(3);
int64_t output_width = output.size(4);
using Vec = vec::Vectorized<scalar_t>;
// parallel on dim N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
int64_t size = channels;
int64_t len = size - (size % Vec::size());
for (const auto i : c10::irange(begin, end)) {
// compute the mean of the input image...
int64_t id0 = od * dD - padD;
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t id1 = std::min(id0 + kD, input_depth + padD);
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
id0 = std::max(id0, (int64_t) 0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
id1 = std::min(id1, input_depth);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
}
}
scalar_t* out = output_data + i * channels;
// Pass I: zero the out lane
int64_t d1 = 0;
for (; d1 < len; d1 += Vec::size()) {
Vec out_vec = Vec(scalar_t(0));
out_vec.store(out + d1);
}
for (; d1 < size; d1++) {
out[d1] = scalar_t(0);
}
if (id0 >= id1 || ih0 >= ih1 || iw0 >= iw1) {
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
continue;
}
// Pass II: compute local sum
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < len; d2 += Vec::size()) {
Vec out_vec = Vec::loadu(out + d2) + Vec::loadu(in + d2);
out_vec.store(out + d2);
}
for (; d2 < size; d2++) {
out[d2] += in[d2];
}
}
}
}
// Pass III: compute local average
int64_t d3 = 0;
for (; d3 < len; d3 += Vec::size()) {
Vec out_vec = Vec::loadu(out + d3) / Vec(scalar_t(divide_factor));
out_vec.store(out + d3);
}
for (; d3 < size; d3++) {
out[d3] = out[d3] / divide_factor;
}
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
}
template <typename scalar_t,
typename std::enable_if<is_reduced_floating_point<scalar_t>::value, int>::type = 0>
void cpu_avg_pool3d_channels_last(
const Tensor& output_,
const Tensor& input_,
int64_t kW, int64_t kH, int64_t kD,
int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
TORCH_CHECK(input_.ndimension() == 5,
"3d average pooling with channels last format supports tensors with 5 dims");
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto input = input_.contiguous(memory_format);
auto output = output_.contiguous(memory_format);
auto input_data = input.data_ptr<BFloat16>();
auto output_data = output.data_ptr<BFloat16>();
int64_t nbatch = input.size(0);
int64_t channels = input.size(1);
int64_t input_depth = input.size(2);
int64_t input_height = input.size(3);
int64_t input_width = input.size(4);
int64_t output_depth = output.size(2);
int64_t output_height = output.size(3);
int64_t output_width = output.size(4);
using bVec = vec::Vectorized<BFloat16>;
using fVec = vec::Vectorized<float>;
// parallel on dim N, D, H, W
at::parallel_for(0, nbatch * output_depth * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
int64_t n = 0;
int64_t od = 0;
int64_t oh = 0;
int64_t ow = 0;
data_index_init(begin, n, nbatch, od, output_depth, oh, output_height, ow, output_width);
// temp buffer for sum, use float as accumulation type
// can't reuse output buffer to store sum since it is BFloat16
auto sum_arr = std::make_unique<float []>(channels);
float* sum = sum_arr.get();
int64_t size = channels;
for (const auto i : c10::irange(begin, end)) {
// compute the mean of the input image...
int64_t id0 = od * dD - padD;
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t id1 = std::min(id0 + kD, input_depth + padD);
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
id0 = std::max(id0, (int64_t) 0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
id1 = std::min(id1, input_depth);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
}
}
BFloat16* out = output_data + i * channels;
// Pass I: zero the out lane
int64_t d1 = 0;
for (; d1 < size - (size % fVec::size()); d1 += fVec::size()) {
fVec sum_fvec = fVec(float(0));
sum_fvec.store(sum + d1);
}
for (; d1 < size; d1++) {
sum[d1] = float(0);
}
if (id0 >= id1 || ih0 >= ih1 || iw0 >= iw1) {
// since the output is not used directly as the accumulation buffer,
// the output buffer needs to be zeroed here when the kernel window is out of range.
for (int64_t k = 0; k < size; k++) {
out[k] = 0;
}
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
continue;
}
// Pass II: compute local sum
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
BFloat16* in = input_data + n * input_depth * input_height * input_width * channels +
id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d2 = 0;
for (; d2 < size - (size % bVec::size()); d2 += bVec::size()) {
bVec data_bvec = bVec::loadu(in + d2);
fVec data_fvec0, data_fvec1;
std::tie(data_fvec0, data_fvec1) = convert_bfloat16_float(data_bvec);
fVec sum_fvec0 = fVec::loadu(sum + d2) + data_fvec0;
fVec sum_fvec1 = fVec::loadu(sum + d2 + fVec::size()) + data_fvec1;
sum_fvec0.store(sum + d2);
sum_fvec1.store(sum + d2 + fVec::size());
}
for (; d2 < size; d2++) {
sum[d2] += float(in[d2]);
}
}
}
}
// Pass III: compute local average
int64_t d3 = 0;
for (; d3 < size - (size % bVec::size()); d3 += bVec::size()) {
fVec out_fvec0 = fVec::loadu(sum + d3) / fVec(float(divide_factor));
fVec out_fvec1 = fVec::loadu(sum + d3 + fVec::size()) / fVec(float(divide_factor));
bVec out_bvec = convert_float_bfloat16(out_fvec0, out_fvec1);
out_bvec.store(out + d3);
}
for (; d3 < size; d3++) {
out[d3] = BFloat16(sum[d3] / divide_factor);
}
// move on to next output index
data_index_step(n, nbatch, od, output_depth, oh, output_height, ow, output_width);
}
});
if (!output_.is_contiguous(memory_format)) {
output_.copy_(output);
}
}
template <typename scalar_t>
void cpu_avg_pool3d_backward(
const Tensor& grad_input_,
const Tensor& grad_output_,
int kW, int kH, int kD,
int dW, int dH, int dD,
int padW, int padH, int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
auto grad_output = grad_output_.contiguous();
auto grad_input = grad_input_.contiguous();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
int64_t ndim = grad_output.ndimension();
// treat batch size and channels as one dimension
int64_t channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1);
int64_t input_depth = grad_input.size(-3);
int64_t input_height = grad_input.size(-2);
int64_t input_width = grad_input.size(-1);
int64_t output_depth = grad_output.size(-3);
int64_t output_height = grad_output.size(-2);
int64_t output_width = grad_output.size(-1);
// parallel on dim of N, C
at::parallel_for(0, channels, 0, [&](int64_t begin, int64_t end) {
for (const auto c : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + c * input_depth * input_height * input_width;
scalar_t* grad_output_ptr = grad_output_data + c * output_depth * output_height * output_width;
for (const auto od : c10::irange(output_depth)) {
for (const auto oh : c10::irange(output_height)) {
for (const auto ow : c10::irange(output_width)) {
int64_t id0 = od * dD - padD;
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t id1 = std::min(id0 + kD, input_depth + padD);
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
id0 = std::max(id0, (int64_t) 0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
id1 = std::min(id1, input_depth);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
}
}
scalar_t grad_delta = grad_output_ptr[od * output_height * output_width + oh * output_width + ow] / divide_factor;
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
grad_input_ptr[id * input_height * input_width + ih * input_width + iw] += grad_delta;
}
}
}
}
}
}
}
});
if (!grad_input_.is_contiguous()) {
grad_input_.copy_(grad_input);
}
}
template <typename scalar_t>
void cpu_avg_pool3d_backward_channels_last(
const Tensor& grad_input_,
const Tensor& grad_output_,
int kW, int kH, int kD,
int dW, int dH, int dD,
int padW, int padH, int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
auto memory_format = at::MemoryFormat::ChannelsLast3d;
auto grad_input = grad_input_.contiguous(memory_format);
auto grad_output = grad_output_.contiguous(memory_format);
auto grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
auto grad_output_data = grad_output.data_ptr<scalar_t>();
int64_t nbatch = grad_input.size(0);
int64_t channels = grad_input.size(1);
int64_t input_depth = grad_input.size(2);
int64_t input_height = grad_input.size(3);
int64_t input_width = grad_input.size(4);
int64_t output_depth = grad_output.size(2);
int64_t output_height = grad_output.size(3);
int64_t output_width = grad_output.size(4);
using Vec = vec::Vectorized<scalar_t>;
// parallel on dim N
at::parallel_for(0, nbatch, 0, [&](int64_t begin, int64_t end) {
for (const auto n : c10::irange(begin, end)) {
scalar_t* grad_input_ptr = grad_input_data + n * input_depth * input_height * input_width * channels;
scalar_t* grad_output_ptr = grad_output_data + n * output_depth * output_height * output_width * channels;
for (const auto od : c10::irange(output_depth)) {
for (const auto oh : c10::irange(output_height)) {
for (const auto ow : c10::irange(output_width)) {
int64_t id0 = od * dD - padD;
int64_t ih0 = oh * dH - padH;
int64_t iw0 = ow * dW - padW;
int64_t id1 = std::min(id0 + kD, input_depth + padD);
int64_t ih1 = std::min(ih0 + kH, input_height + padH);
int64_t iw1 = std::min(iw0 + kW, input_width + padW);
int64_t pool_size = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
id0 = std::max(id0, (int64_t) 0);
ih0 = std::max(ih0, (int64_t) 0);
iw0 = std::max(iw0, (int64_t) 0);
id1 = std::min(id1, input_depth);
ih1 = std::min(ih1, input_height);
iw1 = std::min(iw1, input_width);
int64_t divide_factor;
if (divisor_override.has_value()) {
divide_factor = divisor_override.value();
} else {
if(count_include_pad) {
divide_factor = pool_size;
} else {
divide_factor = (id1 - id0) * (ih1 - ih0) * (iw1 - iw0);
}
}
scalar_t* gout = grad_output_ptr + od * output_height * output_width * channels + oh * output_width * channels + ow * channels;
int64_t size = channels;
int64_t len = size - (size % Vec::size());
for (const auto id : c10::irange(id0, id1)) {
for (const auto ih : c10::irange(ih0, ih1)) {
for (const auto iw : c10::irange(iw0, iw1)) {
scalar_t* gin = grad_input_ptr + id * input_height * input_width * channels + ih * input_width * channels + iw * channels;
int64_t d = 0;
for (; d < len; d += Vec::size()) {
Vec gin_vec = Vec::loadu(gin + d) + Vec::loadu(gout + d) / Vec(scalar_t(divide_factor));
gin_vec.store(gin + d);
}
for (; d < size; d++) {
gin[d] += gout[d] / divide_factor;
}
}
}
}
}
}
}
}
});
if (!grad_input_.is_contiguous(memory_format)) {
grad_input_.copy_(grad_input);
}
}
void avg_pool3d_kernel_impl(
const Tensor& output,
const Tensor& input,
int64_t kW, int64_t kH, int64_t kD,
int64_t dW, int64_t dH, int64_t dD,
int64_t padW, int64_t padH, int64_t padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, input.scalar_type(), "avg_pool3d", [&] {
cpu_avg_pool3d<scalar_t>(output, input, kW, kH, kD, dW, dH, dD, padW, padH, padD, count_include_pad, divisor_override);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, input.scalar_type(), "avg_pool3d_channels_last", [&] {
cpu_avg_pool3d_channels_last<scalar_t>(output, input, kW, kH, kD, dW, dH, dD, padW, padH, padD, count_include_pad, divisor_override);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
void avg_pool3d_backward_kernel_impl(
const Tensor& grad_input,
const Tensor& grad_output,
int kW, int kH, int kD,
int dW, int dH, int dD,
int padW, int padH, int padD,
bool count_include_pad,
c10::optional<int64_t> divisor_override) {
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, grad_output.scalar_type(), "avg_pool3d_backward", [&] {
cpu_avg_pool3d_backward<scalar_t>(grad_input, grad_output, kW, kH, kD, dW, dH, dD, padW, padH, padD, count_include_pad, divisor_override);
});
break;
}
case at::MemoryFormat::ChannelsLast3d: {
AT_DISPATCH_FLOATING_TYPES_AND3(kLong, kBFloat16, kHalf, grad_output.scalar_type(), "avg_pool3d_backward_channels_last", [&] {
cpu_avg_pool3d_backward_channels_last<scalar_t>(grad_input, grad_output, kW, kH, kD, dW, dH, dD, padW, padH, padD, count_include_pad, divisor_override);
});
break;
}
default:
TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous");
}
}
} // anonymous namespace
REGISTER_DISPATCH(avg_pool2d_kernel, &avg_pool2d_kernel_impl);
REGISTER_DISPATCH(avg_pool2d_backward_kernel, &avg_pool2d_backward_kernel_impl);
REGISTER_DISPATCH(avg_pool3d_kernel, &avg_pool3d_kernel_impl);
REGISTER_DISPATCH(avg_pool3d_backward_kernel, &avg_pool3d_backward_kernel_impl);
} // at::native
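For reference, a hypothetical libtorch snippet that should dispatch through avg_pool3d_kernel_impl above with a channels-last-3d input on a CPU build (shapes and pooling parameters are arbitrary):
// Hypothetical libtorch snippet (CPU build assumed); parameters are arbitrary.
#include <torch/torch.h>
#include <iostream>
int main() {
  auto x = torch::randn({2, 8, 16, 16, 16})
               .contiguous(torch::MemoryFormat::ChannelsLast3d);
  auto y = torch::avg_pool3d(x, /*kernel_size=*/{3, 3, 3}, /*stride=*/{2, 2, 2},
                             /*padding=*/{1, 1, 1}, /*ceil_mode=*/false,
                             /*count_include_pad=*/false);
  std::cout << y.sizes() << std::endl;  // [2, 8, 8, 8, 8]
  return 0;
}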

View File

@ -52,8 +52,8 @@ typename std::enable_if<
grad_vec2 = grad_vec2 * fVec(opmath_t(-1.0));
}
if (weight_decay != 0.0){
grad_vec1 = vec::fmadd(param_vec1, fVec(scalar_t(weight_decay)), grad_vec1);
grad_vec2 = vec::fmadd(param_vec2, fVec(scalar_t(weight_decay)), grad_vec2);
grad_vec1 += param_vec1 * fVec(scalar_t(weight_decay));
grad_vec2 += param_vec2 * fVec(scalar_t(weight_decay));
}
if (momentum != 0.0) {
fVec momentum_vec1, momentum_vec2;
@ -61,16 +61,17 @@ typename std::enable_if<
momentum_vec1 = grad_vec1;
momentum_vec2 = grad_vec2;
} else {
momentum_vec1 = fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum));
momentum_vec2 = fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum));
momentum_vec1 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec1, momentum_vec1);
momentum_vec2 = vec::fmadd(fVec(scalar_t(1 - dampening)), grad_vec2, momentum_vec2);
momentum_vec1 =
fVec::loadu(momentum_buf_ptr + d) * fVec(scalar_t(momentum)) +
grad_vec1 * fVec(scalar_t(1 - dampening));
momentum_vec2 =
fVec::loadu(momentum_buf_ptr + d + fVec::size()) * fVec(scalar_t(momentum)) +
grad_vec2 * fVec(scalar_t(1 - dampening));
}
vec::convert_from_float<scalar_t>(momentum_vec1, momentum_vec2).store(momentum_buf_ptr + d);
if (nesterov) {
grad_vec1 = vec::fmadd(momentum_vec1, fVec(scalar_t(momentum)), grad_vec1);
grad_vec2 = vec::fmadd(momentum_vec2, fVec(scalar_t(momentum)), grad_vec2);
grad_vec1 += momentum_vec1 * fVec(scalar_t(momentum));
grad_vec2 += momentum_vec2 * fVec(scalar_t(momentum));
} else {
grad_vec1 = momentum_vec1;
grad_vec2 = momentum_vec2;
@ -141,7 +142,7 @@ typename std::enable_if<
}
if (maximize) grad_vec = grad_vec * Vec(scalar_t(-1.0));
if (weight_decay != 0.0){
grad_vec = vec::fmadd(param_vec, Vec(scalar_t(weight_decay)), grad_vec);
grad_vec += param_vec * Vec(scalar_t(weight_decay));
}
if (momentum != 0.0) {
Vec momentum_vec;
@ -149,12 +150,12 @@ typename std::enable_if<
momentum_vec = grad_vec;
} else {
momentum_vec =
Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum));
momentum_vec = vec::fmadd(Vec(scalar_t(1 - dampening)), grad_vec, momentum_vec);
Vec::loadu(momentum_buf_ptr + d) * Vec(scalar_t(momentum)) +
grad_vec * Vec(scalar_t(1 - dampening));
}
momentum_vec.store(momentum_buf_ptr + d);
if (nesterov) {
grad_vec = vec::fmadd(momentum_vec, Vec(scalar_t(momentum)), grad_vec);
grad_vec += momentum_vec * Vec(scalar_t(momentum));
} else {
grad_vec = momentum_vec;
}
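Both code paths in this hunk compute the same SGD-with-momentum update; the diff only switches between vec::fmadd and an explicit multiply-add. A scalar restatement of that update (a sketch; names are illustrative, and the final learning-rate step is assumed to happen outside the lines shown here):
// Scalar sketch of the fused SGD momentum step the vectorized code above computes.
void sgd_momentum_step(float& param, float& momentum_buf, float grad, bool has_momentum_buf,
                       double weight_decay, double momentum, double dampening,
                       double lr, bool nesterov, bool maximize) {
  if (maximize) grad = -grad;
  if (weight_decay != 0.0) grad += param * (float)weight_decay;
  if (momentum != 0.0) {
    if (!has_momentum_buf) {
      momentum_buf = grad;  // first step: buffer starts as the raw gradient
    } else {
      momentum_buf = momentum_buf * (float)momentum + grad * (float)(1 - dampening);
    }
    grad = nesterov ? grad + momentum_buf * (float)momentum : momentum_buf;
  }
  param -= (float)lr * grad;  // assumed final step, not shown in the hunk
}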

View File

@ -185,78 +185,11 @@ inline void tinygemm_kernel(
#if !defined(C10_MOBILE) && defined(__aarch64__)
#include <arm_neon.h>
inline float reduce(float32x4_t x) {
static inline float reduce(float32x4_t x) {
auto sum = vpaddq_f32(x, x);
return vgetq_lane_f32(vpaddq_f32(sum, sum), 0);
}
inline float32x4x2_t load_as_float32x4x2(const Half* ptr) {
float16x8_t f16_val = vld1q_f16(reinterpret_cast<const float16_t *>(ptr));
auto val_low = vcvt_f32_f16(vget_low_f16(f16_val));
auto val_high = vcvt_f32_f16(vget_high_f16(f16_val));
return {val_low, val_high};
}
inline float32x4_t load_as_float32x4(const Half* ptr) {
return vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(ptr)));
}
inline float32x4x2_t load_as_float32x4x2(const BFloat16* ptr) {
int32x4_t shift = vdupq_n_s32(16);
uint16x8_t u16_val = vld1q_u16(reinterpret_cast<const uint16_t *>(ptr));
uint32x4_t int_low = vmovl_u16(vget_low_u16(u16_val));
uint32x4_t int_high = vmovl_u16(vget_high_u16(u16_val));
return {vreinterpretq_f32_u32(vshlq_u32(int_low, shift)), vreinterpretq_f32_u32(vshlq_u32(int_high, shift))};
}
inline float32x4_t load_as_float32x4(const BFloat16* ptr) {
int32x4_t shift = vdupq_n_s32(16);
uint32x4_t as_int = vmovl_u16(vld1_u16(reinterpret_cast<const uint16_t *>(ptr)));
return vreinterpretq_f32_u32(vshlq_u32(as_int, shift));
}
inline float32x4_t load_as_float32x4(const float* ptr) {
return vld1q_f32(ptr);
}
inline float32x4x2_t load_as_float32x4x2(const float* ptr) {
return {vld1q_f32(ptr), vld1q_f32(ptr + 4)};
}
template <int BLOCK_M, int BLOCK_N, typename T>
inline void tinygemm_kernel_(
const T* RESTRICT A,
const int8_t* RESTRICT B,
const T* RESTRICT scales,
T* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
for (const auto m : c10::irange(BLOCK_M)) {
float32x4_t c_val[BLOCK_N];
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
c_val[i] = vdupq_n_f32(0.0);
});
for (int k = 0; k < K; k += 8) {
auto a_val = load_as_float32x4x2(A + m * lda + k);
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
c_val[i] = vfmaq_f32(c_val[i], a_val.val[1], b_val_high);
c_val[i] = vfmaq_f32(c_val[i], a_val.val[0], b_val_low);
});
}
float32x4_t scale_val = load_as_float32x4(scales);
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
});
}
}
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const Half* RESTRICT A,
@ -267,33 +200,30 @@ inline void tinygemm_kernel(
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const BFloat16* RESTRICT A,
const int8_t* RESTRICT B,
const BFloat16* RESTRICT scales,
BFloat16* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
}
for (const auto m : c10::irange(BLOCK_M)) {
float32x4_t c_val[BLOCK_N];
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
c_val[i] = vdupq_n_f32(0.0);
});
for (int k = 0; k < K; k += 8) {
float16x8_t a_val = vld1q_f16(reinterpret_cast<const float16_t *>(A) + m * lda + k);
auto a_val_low = vcvt_f32_f16(vget_low_f16(a_val));
auto a_val_high = vcvt_f32_f16(vget_high_f16(a_val));
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
int16x8_t b_val = vmovl_s8(vld1_s8(B + i * ldb + k));
auto b_val_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_val)));
auto b_val_high = vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_val)));
c_val[i] = vfmaq_f32(c_val[i], a_val_high, b_val_high);
c_val[i] = vfmaq_f32(c_val[i], a_val_low, b_val_low);
});
}
template <int BLOCK_M, int BLOCK_N>
inline void tinygemm_kernel(
const float* RESTRICT A,
const int8_t* RESTRICT B,
const float* RESTRICT scales,
float* RESTRICT C,
int lda,
int ldb,
int ldc,
int K) {
tinygemm_kernel_<BLOCK_M, BLOCK_N>(A, B, scales, C, lda, ldb, ldc, K);
float32x4_t scale_val = vcvt_f32_f16(vld1_f16(reinterpret_cast<const float16_t *>(scales)));
c10::ForcedUnroll<BLOCK_N>{}([&](auto i) {
C[m * ldc + i] = reduce(c_val[i]) * vgetq_lane_f32(scale_val, i);
});
}
}
#endif
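The NEON kernels above compute, for each row m and each of up to four output columns n, a dot product against int8 weights scaled by a per-column factor. A scalar reference of that contraction (a sketch with float activations; the real kernels also handle Half/BFloat16 by widening to float):
// Scalar reference for the contraction the NEON tinygemm computes (BLOCK_N is at most 4 there,
// since the per-column scales are loaded into a single float32x4_t lane vector).
#include <cstdint>
template <int BLOCK_M, int BLOCK_N>
void tinygemm_ref(const float* A, const int8_t* B, const float* scales, float* C,
                  int lda, int ldb, int ldc, int K) {
  for (int m = 0; m < BLOCK_M; ++m) {
    for (int n = 0; n < BLOCK_N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * lda + k] * (float)B[n * ldb + k];  // int8 weight widened to float
      }
      C[m * ldc + n] = acc * scales[n];                 // per-column scale
    }
  }
}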

View File

@ -7,7 +7,6 @@
#include <ATen/TensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/tunable/Tunable.h>
#include <ATen/cuda/tunable/TunableGemm.h>
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
@ -892,108 +891,28 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
#ifdef USE_ROCM
auto tuning_ctx = at::cuda::tunable::getTuningContext();
if (tuning_ctx->IsTunableOpEnabled()) {
#define TUNABLE_DISPATCH(BLASOP_A, BLASOP_B) \
if (mat1.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e4m3fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
} \
else if (mat1.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
if (mat2.scalar_type() == ScalarType::Float8_e4m3fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e4m3fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
else if (mat2.scalar_type() == ScalarType::Float8_e5m2fnuz) { \
static at::cuda::tunable::ScaledGemmTunableOp< \
at::Float8_e5m2fnuz, at::Float8_e5m2fnuz, scalar_t, \
BLASOP_A, BLASOP_B> scaledgemm{}; \
scaledgemm(&params); \
} \
}
AT_DISPATCH_V2(out_dtype_, "_tunable_scaled_gemm", AT_WRAP([&] {
bool transa_ = ((args.transa != 'n') && (args.transa != 'N'));
bool transb_ = ((args.transb != 'n') && (args.transb != 'N'));
at::cuda::tunable::ScaledGemmParams<scalar_t> params;
params.transa = args.transa;
params.transb = args.transb;
params.m = args.m;
params.n = args.n;
params.k = args.k;
params.a = args.mata->data_ptr();
params.a_scale_ptr = scale_a ? scale_a->data_ptr() : nullptr;
params.lda = args.lda;
params.a_dtype = args.mata->scalar_type();
params.b = args.matb->data_ptr();
params.b_scale_ptr = scale_b ? scale_b->data_ptr() : nullptr;
params.ldb = args.ldb;
params.b_dtype = args.matb->scalar_type();
params.bias_ptr = bias ? bias->data_ptr(): nullptr;
params.bias_dtype = bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_;
params.c = args.result->data_ptr();
params.c_scale_ptr = scale_result ? scale_result->data_ptr() : nullptr;
params.ldc = args.result_ld;
params.c_dtype = out_dtype_;
params.amax_ptr = amax.data_ptr();
params.use_fast_accum = use_fast_accum;
if (transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::T)
}
else if (transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::T, at::cuda::tunable::BlasOp::N)
}
else if (!transa_ && transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::T)
}
else if (!transa_ && !transb_) {
TUNABLE_DISPATCH(at::cuda::tunable::BlasOp::N, at::cuda::tunable::BlasOp::N)
}
else {
TORCH_CHECK(false, "unreachable");
}
}),
kHalf, kBFloat16, kFloat8_e4m3fnuz, kFloat8_e5m2fnuz, AT_EXPAND(AT_FLOATING_TYPES));
#undef TUNABLE_DISPATCH
}
else
#endif
{
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
scale_a ? scale_a->data_ptr() : nullptr,
args.lda,
args.mata->scalar_type(),
args.matb->data_ptr(),
scale_b ? scale_b->data_ptr() : nullptr,
args.ldb,
args.matb->scalar_type(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
scale_result ? scale_result->data_ptr() : nullptr,
args.result_ld,
out_dtype_,
amax.data_ptr(),
use_fast_accum);
}
at::cuda::blas::scaled_gemm(
args.transa,
args.transb,
args.m,
args.n,
args.k,
args.mata->data_ptr(),
scale_a ? scale_a->data_ptr() : nullptr,
args.lda,
args.mata->scalar_type(),
args.matb->data_ptr(),
scale_b ? scale_b->data_ptr() : nullptr,
args.ldb,
args.matb->scalar_type(),
bias ? bias->data_ptr(): nullptr,
bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? at::ScalarType::Half : out_dtype_,
args.result->data_ptr(),
scale_result ? scale_result->data_ptr() : nullptr,
args.result_ld,
out_dtype_,
amax.data_ptr(),
use_fast_accum);
#else
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
#endif

View File

@ -4,7 +4,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Resize.h>
#include <ATen/native/TensorShape.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/util/TypeCast.h>
#ifndef AT_PER_OPERATOR_HEADERS
@ -704,15 +703,12 @@ void split_with_sizes_copy_out_cuda(
IntArrayRef split_sizes,
int64_t dim,
TensorList out) {
const bool is_capturing = at::cuda::currentStreamCaptureStatusMayInitCtx() !=
at::cuda::CaptureStatus::None;
bool contiguous_no_cast = self.is_non_overlapping_and_dense();
for (const auto& t : out) {
contiguous_no_cast &= t.is_non_overlapping_and_dense();
contiguous_no_cast &= (t.dtype() == self.dtype());
}
// TODO(yifu): make the fast path work for CUDA graph
if (!is_capturing && contiguous_no_cast) {
if (contiguous_no_cast) {
// Perform equivalent checks performed by the composite impl
if (dim < 0) {
dim = at::maybe_wrap_dim(dim, self.dim());

View File

@ -29,30 +29,6 @@ void run_cudnn_SDP_fprop(
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset) {
TORCH_CHECK(
false, "PyTorch was not compiled with cuDNN Flash Attention enabled!");
}
} // namespace native
} // namespace at
@ -97,22 +73,6 @@ using graph_and_tensors = std::tuple<
std::shared_ptr<fe::graph::Tensor_attributes> // Stats
>;
using graph_and_tensors_backward = std::tuple<
std::shared_ptr<fe::graph::Graph>,
std::shared_ptr<fe::graph::Tensor_attributes>, // Q,
std::shared_ptr<fe::graph::Tensor_attributes>, // K,
std::shared_ptr<fe::graph::Tensor_attributes>, // V,
std::shared_ptr<fe::graph::Tensor_attributes>, // Attn_scale
std::shared_ptr<fe::graph::Tensor_attributes>, // Seed,
std::shared_ptr<fe::graph::Tensor_attributes>, // Offset,
std::shared_ptr<fe::graph::Tensor_attributes>, // O,
std::shared_ptr<fe::graph::Tensor_attributes>, // dO,
std::shared_ptr<fe::graph::Tensor_attributes>, // stats,
std::shared_ptr<fe::graph::Tensor_attributes>, // dQ,
std::shared_ptr<fe::graph::Tensor_attributes>, // dK,,
std::shared_ptr<fe::graph::Tensor_attributes> // dV,
>;
#define MAX_MHA_DIM 4
struct MHAParams {
@ -218,7 +178,8 @@ struct MHACacheKeyWrapper : ParamsWrapper<MHAParams> {
template <typename T, typename KeyType>
struct MHAGraphCache {
std::unordered_map<KeyType, T, ParamsWrapperHash<KeyType>> engine_cache;
std::unordered_map<KeyType, graph_and_tensors, ParamsWrapperHash<KeyType>>
engine_cache;
// no mutexes here as caches are now thread local for v8, can also return a
// pointer to the Execution Plan if we know it will not be invalidated by
@ -241,8 +202,6 @@ struct MHAGraphCache {
// be thread safe across all engines see Limitations in
// https://docs.nvidia.com/deeplearning/cudnn/release-notes/index.html
thread_local MHAGraphCache<graph_and_tensors, MHACacheKeyWrapper> mhagraphcache;
thread_local MHAGraphCache<graph_and_tensors_backward, MHACacheKeyWrapper>
mhagraphbackwardcache;
auto build_graph_and_tensors(
int64_t b,
@ -268,12 +227,10 @@ auto build_graph_and_tensors(
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
@ -297,7 +254,7 @@ auto build_graph_and_tensors(
params.v_stride.begin(), params.v_stride.end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
.set_name("attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
@ -319,7 +276,7 @@ auto build_graph_and_tensors(
.set_data_type(fe::DataType_t::INT32));
auto scaled_dot_product_flash_attention_options =
fe::graph::SDPA_attributes()
.set_name("CUDNN_SDPA")
.set_name("flash_attention")
.set_is_inference(return_softmaxstats == false)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale)
@ -330,12 +287,12 @@ auto build_graph_and_tensors(
}
auto seq_q = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_q")
.set_name("seq_q")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seq_kv")
.set_name("seq_kv")
.set_dim({b, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
@ -367,146 +324,7 @@ auto build_graph_and_tensors(
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
std::move(mha_graph),
std::move(Q),
std::move(K),
std::move(V),
std::move(attn_scale),
std::move(seed),
std::move(offset),
std::move(O),
std::move(Stats));
}
auto build_graph_and_tensors_backward(
int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset,
cudnnHandle_t& handle,
MHAParams& params) {
auto dtype = fe::DataType_t::HALF;
if (q.scalar_type() == kBFloat16) {
dtype = fe::DataType_t::BFLOAT16;
}
auto mha_graph = std::make_shared<fe::graph::Graph>();
// We're baking in float accumulation and scale types
// in theory the graph may support other types, but they
// have not been tested
mha_graph->set_io_data_type(dtype)
.set_intermediate_data_type(fe::DataType_t::FLOAT)
.set_compute_data_type(fe::DataType_t::FLOAT);
auto Q = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Q")
.set_dim(std::vector<int64_t>(q.sizes().begin(), q.sizes().end()))
.set_stride(
std::vector<int64_t>(q.strides().begin(), q.strides().end())));
auto K = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("K")
.set_dim(std::vector<int64_t>(k.sizes().begin(), k.sizes().end()))
.set_stride(
std::vector<int64_t>(k.strides().begin(), k.strides().end())));
auto V = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("V")
.set_dim(std::vector<int64_t>(v.sizes().begin(), v.sizes().end()))
.set_stride(
std::vector<int64_t>(v.strides().begin(), v.strides().end())));
auto attn_scale =
mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Attn_scale")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_is_pass_by_value(true)
.set_data_type(fe::DataType_t::FLOAT));
auto Seed = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Seed")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto Offset = mha_graph->tensor(fe::graph::Tensor_attributes()
.set_name("Offset")
.set_dim({1, 1, 1, 1})
.set_stride({1, 1, 1, 1})
.set_data_type(fe::DataType_t::INT32));
auto O = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("O")
.set_dim(std::vector<int64_t>(o.sizes().begin(), o.sizes().end()))
.set_stride(
std::vector<int64_t>(o.strides().begin(), o.strides().end())));
auto STATS = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("Stats")
.set_dim(std::vector<int64_t>(
softmaxstats.sizes().begin(), softmaxstats.sizes().end()))
.set_stride(std::vector<int64_t>(
softmaxstats.strides().begin(), softmaxstats.strides().end()))
.set_data_type(fe::DataType_t::FLOAT));
auto DO = mha_graph->tensor(
fe::graph::Tensor_attributes()
.set_name("DO")
.set_dim(std::vector<int64_t>(dO.sizes().begin(), dO.sizes().end()))
.set_stride(
std::vector<int64_t>(dO.strides().begin(), dO.strides().end())));
auto sdpa_backward_options = fe::graph::SDPA_backward_attributes()
.set_name("CUDNN_SDPA_BACKWARD")
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
if (dropout_probability != 0.0f) {
sdpa_backward_options.set_dropout(dropout_probability, Seed, Offset);
}
auto [DQ, DK, DV] =
mha_graph->sdpa_backward(Q, K, V, O, DO, STATS, sdpa_backward_options);
DQ->set_output(true)
.set_dim(std::vector<int64_t>(dQ.sizes().begin(), dQ.sizes().end()))
.set_stride(
std::vector<int64_t>(dQ.strides().begin(), dQ.strides().end()));
DK->set_output(true)
.set_dim(std::vector<int64_t>(dK.sizes().begin(), dK.sizes().end()))
.set_stride(
std::vector<int64_t>(dK.strides().begin(), dK.strides().end()));
DV->set_output(true)
.set_dim(std::vector<int64_t>(dV.sizes().begin(), dV.sizes().end()))
.set_stride(
std::vector<int64_t>(dV.strides().begin(), dV.strides().end()));
AT_CUDNN_FRONTEND_CHECK(mha_graph->validate());
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle));
AT_CUDNN_FRONTEND_CHECK(
mha_graph->create_execution_plans({fe::HeurMode_t::A}));
AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle));
AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle));
return std::make_tuple(
std::move(mha_graph),
std::move(Q),
std::move(K),
std::move(V),
std::move(attn_scale),
std::move(Seed),
std::move(Offset),
std::move(O),
std::move(DO),
std::move(STATS),
std::move(DQ),
std::move(DK),
std::move(DV));
mha_graph, Q, K, V, attn_scale, seed, offset, O, Stats);
}
void run_cudnn_SDP_fprop(
@ -589,92 +407,11 @@ void run_cudnn_SDP_fprop(
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
TORCH_INTERNAL_ASSERT(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mhagraphcache.update(key, graph_and_tensors_values);
}
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset) {
cudnnHandle_t handle = getCudnnHandle();
auto key = MHACacheKeyWrapper(
b, h, s_q, s_kv, d, q, k, v, dropout_probability, is_causal, true);
auto graph_and_tensors_backward_ptr = mhagraphbackwardcache.find(key);
graph_and_tensors_backward graph_and_tensors_backward_values;
if (graph_and_tensors_backward_ptr) {
graph_and_tensors_backward_values = *graph_and_tensors_backward_ptr;
} else {
graph_and_tensors_backward_values = build_graph_and_tensors_backward(
b,
h,
s_q,
s_kv,
d,
scaling_factor,
is_causal,
dropout_probability,
q,
k,
v,
o,
dO,
softmaxstats,
dQ,
dK,
dV,
dropoutseed,
dropoutoffset,
handle,
key.pod);
}
auto
[mha_graph, Q, K, V, attn_scale, Seed, Offset, O, Do, Stats, Dq, Dk, Dv] =
graph_and_tensors_backward_values;
std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*>
variant_pack = {// inputs
{Q, q.data_ptr()},
{K, k.data_ptr()},
{V, v.data_ptr()},
{O, o.data_ptr()},
{Do, dO.data_ptr()},
{Stats, softmaxstats.data_ptr()},
// outputs
{Dq, dQ.data_ptr()},
{Dk, dK.data_ptr()},
{Dv, dV.data_ptr()},
// pass by value
{attn_scale, &scaling_factor}};
if (dropout_probability != 0.0f) {
variant_pack[Seed] = dropoutseed.data_ptr();
variant_pack[Offset] = dropoutoffset.data_ptr();
}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(!workspace_size || workspace_ptr.get());
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mhagraphbackwardcache.update(key, graph_and_tensors_backward_values);
}
} // namespace native
} // namespace at
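
The cache comments above note that mhagraphcache and mhagraphbackwardcache are thread_local, so lookups need no mutex. A minimal sketch of that pattern, with hypothetical Key/Value stand-ins for MHACacheKeyWrapper and graph_and_tensors(_backward):

#include <cstdio>
#include <string>
#include <unordered_map>
#include <utility>

template <typename Key, typename Value>
struct GraphCache {
  std::unordered_map<Key, Value> engine_cache;
  Value* find(const Key& key) {
    auto it = engine_cache.find(key);
    return it == engine_cache.end() ? nullptr : &it->second;
  }
  void update(const Key& key, Value value) {
    engine_cache[key] = std::move(value);
  }
};

// One instance per thread: no reader/writer races, at the cost of rebuilding
// the (expensive) graph once per thread that needs it.
thread_local GraphCache<std::string, int> tls_graph_cache;

int main() {
  if (tls_graph_cache.find("b8_h16_s128_s128") == nullptr) {
    tls_graph_cache.update("b8_h16_s128_s128", 42);  // stand-in for a built cuDNN graph
  }
  std::printf("%d\n", *tls_graph_cache.find("b8_h16_s128_s128"));
}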

View File

@ -21,27 +21,5 @@ void run_cudnn_SDP_fprop(
Tensor& o,
Tensor& dropoutseed,
Tensor& dropoutoffset);
void run_cudnn_SDP_bprop(
int64_t b,
int64_t h,
int64_t s_q,
int64_t s_kv,
int64_t d,
float scaling_factor,
bool is_causal,
float dropout_probability,
const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& o,
const Tensor& dO,
const Tensor& softmaxstats,
Tensor& dQ,
Tensor& dK,
Tensor& dV,
const Tensor& dropoutseed,
const Tensor& dropoutoffset);
} // namespace native
}
} // namespace at

View File

@ -198,40 +198,24 @@ Tensor mkldnn_reorder_conv3d_weight(
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
c10::OptionalArrayRef<int64_t> input_size) {
int64_t groups) {
mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_conv3d_weight");
const auto padding_expanded = expand_param_if_needed(padding, "padding", 3);
const auto stride_expanded = expand_param_if_needed(stride, "stride", 3);
const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 3);
ideep::dims src_dims = ideep::dims();
bool is_channels_last = false;
auto memory_format = at::MemoryFormat::Contiguous;
if (input_size.has_value()) {
src_dims = input_size.value().vec();
// if has input size, we always use channels last.
is_channels_last = true;
memory_format = at::MemoryFormat::ChannelsLast3d;
}
auto w = itensor_from_mkldnn(self);
auto self_ = self.is_mkldnn() ? self : self.contiguous(memory_format);
auto w = itensor_from_tensor(self_);
auto desc = ideep::convolution_forward::expected_weights_desc(
w.get_dims(),
w.get_data_type(),
stride_expanded,
padding_expanded,
padding_expanded,
dilation_expanded,
groups,
ideep::algorithm::convolution_direct,
ideep::prop_kind::forward,
w.get_data_type(),
src_dims,
ideep::attr_t(),
is_channels_last);
auto desc =
ideep::convolution_forward::expected_weights_desc(
w.get_dims(),
w.get_data_type(),
stride_expanded,
padding_expanded,
padding_expanded,
dilation_expanded,
groups,
ideep::algorithm::convolution_direct);
ideep::tensor result;
result.init(desc);
result.feed_from(w);
@ -239,21 +223,6 @@ Tensor mkldnn_reorder_conv3d_weight(
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
}
static Tensor mkldnn_reorder_conv_weight(
const Tensor& self,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
c10::OptionalArrayRef<int64_t> input_size) {
TORCH_CHECK((self.dim() == 4 || self.dim() == 5), "mkldnn_reorder_conv_weight only supports conv2d and conv3d");
if (self.dim() == 4) {
return at::native::mkldnn_reorder_conv2d_weight(self, padding, stride, dilation, groups, input_size);
} else {
return at::native::mkldnn_reorder_conv3d_weight(self, padding, stride, dilation, groups, input_size);
}
}
static Tensor mkldnn_reorder_linear_weight(
const Tensor& self,
c10::optional<int64_t> batch_size_opt) {
@ -517,7 +486,7 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) {
TORCH_FN(mkldnn_reorder_linear_weight));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_convolution_weight"),
TORCH_FN(mkldnn_reorder_conv_weight));
TORCH_FN(mkldnn_reorder_conv2d_weight));
m.impl(
TORCH_SELECTIVE_NAME("mkldnn::_reorder_mkldnn_rnn_layer_weight"),
TORCH_FN(mkldnn_reorder_mkldnn_rnn_layer_weight));
@ -548,8 +517,7 @@ Tensor mkldnn_reorder_conv3d_weight(
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
c10::OptionalArrayRef<int64_t> input_size) {
int64_t groups) {
TORCH_CHECK(false, "mkldnn_reorder_conv3d_weight: MKL-DNN build is disabled");
}

View File

@ -25,10 +25,6 @@ typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
-(MPSGraphTensor * _Nonnull) fastFourierTransformWithTensor:(MPSGraphTensor * _Nonnull) tensor
axes:(NSArray<NSNumber *> * _Nonnull) axes
descriptor:(MPSGraphFFTDescriptor * _Nonnull) descriptor

View File

@ -2,7 +2,6 @@
#pragma once
#include <initializer_list>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
@ -72,9 +71,6 @@ static inline std::string getMPSTypeString(const Tensor& t, bool short_name = fa
return getMPSTypeString(t.scalar_type(), short_name);
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
static inline std::string scalarToMetalTypeString(const Tensor& t) {
return scalarToMetalTypeString(t.scalar_type());
}
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
@ -333,30 +329,6 @@ inline bool is_dense_in_storage(const at::Tensor& t) {
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src, unsigned nparams_ = 0): shaderSource(src), nparams(nparams_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname);
}
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname);
}
private:
id<MTLComputePipelineState> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
id<MTLLibrary> getLibrary();
id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
unsigned nparams;
id<MTLLibrary> library = nil;
std::unordered_map<std::string, id<MTLLibrary>> libMap;
std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
};
static inline void mtl_setBuffer(id<MTLComputeCommandEncoder> encoder, const Tensor& t, unsigned idx) {
[encoder setBuffer:getMTLBufferStorage(t)
offset:t.storage_offset() * t.element_size()
@ -419,8 +391,4 @@ inline bool supportedFloatingType(const Tensor& t) {
return supportedFloatingType(t.scalar_type());
}
inline bool needsGather(const Tensor& t) {
return !t.is_contiguous() || t.storage_offset();
}
} // namespace at::native::mps
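
A small illustration of when the needsGather() helper above fires, assuming a libtorch build: a transpose is non-contiguous, while a row slice stays contiguous but carries a nonzero storage offset, and both trip the predicate.

#include <ATen/ATen.h>
#include <cstdio>

int main() {
  auto t = at::arange(12).reshape({3, 4});
  auto transposed = t.t();      // !is_contiguous(), storage_offset() == 0
  auto sliced = t.slice(0, 1);  // is_contiguous(), storage_offset() == 4
  std::printf("transposed: contiguous=%d offset=%lld\n",
              (int)transposed.is_contiguous(), (long long)transposed.storage_offset());
  std::printf("sliced:     contiguous=%d offset=%lld\n",
              (int)sliced.is_contiguous(), (long long)sliced.storage_offset());
}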

View File

@ -6,7 +6,6 @@
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -368,8 +367,9 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
TORCH_CHECK(src.is_mps(), "Placeholder storage has not been allocated on MPS device!");
// extract the pointer to MTLBuffer from the Tensor's storage
id<MTLBuffer> srcBuf = getMTLBufferStorage(src);
bool sliceViewTensor = canSliceViewTensor(src, mpsShape);
// a view tensor could be contiguous (e.g., slice ops) or non-contiguous (e.g., transpose())
if (needsGather(src) && gatherTensorData) {
if ((!src.is_contiguous() || (src.storage_offset() && !sliceViewTensor)) && gatherTensorData) {
Tensor emptyShell = Tensor();
// use "_tensor" from Placeholder to retain view's output during its usage in other ops
_tensor = gatherViewTensor(src, emptyShell);
@ -389,9 +389,13 @@ Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor,
const auto scalar_type = _tensor.scalar_type();
dataType = _tensor.dim() == 0 ? getMPSScalarType(scalar_type) : getMPSDataType(scalar_type);
}
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape ? mpsShape : getMPSShape(_tensor)
dataType:dataType] autorelease];
if (src.is_contiguous() && src.storage_offset() && sliceViewTensor) {
_value = getMPSGraphTensorDataForView(src, mpsShape, dataType);
} else {
_value = [[[MPSGraphTensorData alloc] initWithMTLBuffer:srcBuf
shape:mpsShape ? mpsShape : getMPSShape(_tensor)
dataType:dataType] autorelease];
}
TORCH_INTERNAL_ASSERT(_value);
_placeholder = mpsGraphTensor;
@ -612,74 +616,4 @@ id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEnco
return kernelDataOffsets;
}
id<MTLLibrary> MetalShaderLibrary::getLibrary() {
if (C10_UNLIKELY(!library)) {
TORCH_INTERNAL_ASSERT(nparams == 0);
library = compileLibrary(shaderSource);
}
return library;
}
id<MTLLibrary> MetalShaderLibrary::getLibrary(const std::initializer_list<std::string>& params) {
TORCH_INTERNAL_ASSERT(nparams == params.size());
std::string key = "";
for (auto p : params) {
key += ":" + p;
}
auto lib = libMap[key];
if (lib) {
return lib;
}
auto it = params.begin();
switch (nparams) {
case 1:
lib = compileLibrary(fmt::format(shaderSource, *it));
break;
case 2: {
auto& first = *it++;
auto& second = *it;
lib = compileLibrary(fmt::format(shaderSource, first, second));
break;
}
case 3: {
auto& first = *it++;
auto& second = *it++;
auto& third = *it;
lib = compileLibrary(fmt::format(shaderSource, first, second, third));
break;
}
default:
TORCH_INTERNAL_ASSERT(false, "Unsupported number of paramaters ", nparams);
}
return libMap[key] = lib;
}
id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1
: MTLLanguageVersion2_3];
auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
auto device = MPSDevice::getInstance()->device();
library = [device newLibraryWithSource:str options:options error:&error];
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
return library;
}
id<MTLComputePipelineState> MetalShaderLibrary::getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname) {
auto key = fmt::format("{}:{}", reinterpret_cast<void*>(lib), fname);
auto cpl = cplMap[key];
if (cpl) {
return cpl;
}
NSError* error = nil;
id<MTLFunction> func = [lib newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func, "Failed to create function state object for: ", fname);
cpl = [[lib device] newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(cpl, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
return cplMap[key] = cpl;
}
} // namespace at::native::mps

View File

@ -814,7 +814,7 @@ static void elu_variants_out_mps(const Tensor& self,
auto resultMemFormat = result.suggest_memory_format();
bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat));
Tensor out;
if (executeGatherOp) {
if (executeGatherOp && resultMemFormat == MemoryFormat::ChannelsLast) {
out = at::empty_like(result, MemoryFormat::Contiguous);
}

View File

@ -6,8 +6,6 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/BinaryKernel.h>
// For MTLLanguageVersion_3_1
#include <ATen/native/mps/MPSGraphSonomaOps.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -24,7 +22,7 @@
namespace at::native {
namespace mps {
static MetalShaderLibrary lib(R"BINARY_METAL(
static const char* METAL_BINARY = R"BINARY_METAL(
#include <metal_stdlib>
using namespace metal;
@ -192,25 +190,24 @@ kernel void nextafter_kernel(constant void * input_ [[buffer(0)]],
device void * out_ [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
auto input = *(constant T*)((constant uint8_t*)input_ + offsets[tid].y);
auto other = *(constant T*)((constant uint8_t*)other_ + offsets[tid].z);
#if __METAL_VERSION__ >= 310
*out = nextafter(input, other);
#else
if (input == other) {
*out = input;
} else if (isnan(input) || isnan(other)) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
if (*input == *other)
{
*out = *other;
}
else if (isnan(*input) || isnan(*other))
{
*out = NAN;
} else if (input == 0) {
constexpr auto one = as_type<T>(static_cast<U>(1));
*out = other > 0 ? one : -one;
} else {
U bits = as_type<U>(input);
(input > 0) ^ (input > other) ? bits++ : bits--;
}
else
{
U bits = as_type<U>(*input);
bits = bits + ((*other > *input) ? 1 : -1);
*out = as_type<T>(bits);
}
#endif
}
#define REGISTER_NEXTAFTER_OP(DTYPE, UTYPE) \
@ -252,7 +249,43 @@ kernel void complex_kernel<DTYPE>( \
REGISTER_COMPLEX_OUT_OP(float);
REGISTER_COMPLEX_OUT_OP(half);
)BINARY_METAL");
)BINARY_METAL";
using namespace mps;
static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> binaryLibrary = nil;
if (binaryLibrary) {
return binaryLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
return binaryLibrary;
}
static id<MTLComputePipelineState> binaryPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> binaryLib = compileBinaryOpsLibrary(device);
id<MTLFunction> binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
@ -269,10 +302,10 @@ static void binary_mps_impl(TensorIteratorBase& iter, const std::string func_nam
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input);
const std::string kernel = func_name + "_" + scalarToMetalTypeString(input.scalar_type());
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
id<MTLComputePipelineState> binaryPSO = lib.getPipelineStateForFunc(kernel);
id<MTLComputePipelineState> binaryPSO = binaryPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {input, other});
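
The pre-Metal-3.1 branch of the nextafter kernel above steps the raw IEEE-754 bit pattern to reach the adjacent representable value. A plain C++20 analogue of the same trick, float32 only, using std::bit_cast in place of Metal's as_type (illustration only):

#include <bit>
#include <cmath>
#include <cstdint>
#include <cstdio>

float nextafter_bits(float input, float other) {
  if (input == other) return input;
  if (std::isnan(input) || std::isnan(other)) return NAN;
  if (input == 0.0f) {
    // Smallest positive subnormal, signed toward the target.
    const float tiny = std::bit_cast<float>(std::uint32_t{1});
    return other > 0 ? tiny : -tiny;
  }
  std::uint32_t bits = std::bit_cast<std::uint32_t>(input);
  // Incrementing the bit pattern grows the magnitude, decrementing shrinks it;
  // whether that moves toward `other` depends on the sign of `input`.
  ((input > 0) ^ (input > other)) ? ++bits : --bits;
  return std::bit_cast<float>(bits);
}

int main() {
  std::printf("%a vs %a\n", nextafter_bits(1.0f, 2.0f), std::nextafterf(1.0f, 2.0f));
  std::printf("%a vs %a\n", nextafter_bits(-1.0f, 0.0f), std::nextafterf(-1.0f, 0.0f));
}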

View File

@ -97,7 +97,7 @@ static void binaryOpTensor(const Tensor& self,
Tensor output = output_;
bool needsCopyToOutput = false;
if (needsGather(output_) || (output_.is_view() && (self.is_alias_of(output_) || other.is_alias_of(output_)))) {
if (!output_.is_contiguous() || (output_.is_view() && (self.is_alias_of(output_) || other.is_alias_of(output_)))) {
output = at::empty(output_.sizes(), output_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
needsCopyToOutput = true;
}

View File

@ -12,7 +12,7 @@
namespace at::native {
namespace mps {
static MetalShaderLibrary lib(R"METAL(
static const char* BITWISE_OPS_TEMPLATE = R"METAL(
kernel void bitwise_and_tensor(constant uint& length [[buffer(0)]],
device {0} *out [[buffer(1)]],
@ -90,8 +90,7 @@ kernel void bitwise_not(constant uint& length [[buffer(0)]],
}}
out[offset] = ~a[offset];
}}
)METAL",
3);
)METAL";
static const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
@ -118,12 +117,48 @@ static const std::string& getMetalType(const c10::Scalar& s) {
return getMetalType(s.type());
}
template <typename ScalarOrTensor>
static id<MTLComputePipelineState> getCPLState(const Tensor& t1,
const Tensor& t2,
const ScalarOrTensor& t3,
static id<MTLLibrary> compileBitwiseOpsLibrary(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& t3) {
auto key = t1 + t2 + t3;
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
return it->second;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc =
[device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(BITWISE_OPS_TEMPLATE, t1, t2, t3).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
return rc;
}
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& t3,
const std::string& fname) {
return lib.getPipelineStateForFunc(fname, {getMetalType(t1), getMetalType(t2), getMetalType(t3)});
auto key = t1 + t2 + t3 + fname;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
auto it = cplMap.find(key);
if (it != cplMap.end()) {
return it->second;
}
NSError* error = nil;
auto library = compileBitwiseOpsLibrary(device, t1, t2, t3);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}
static void handle_tensor_tensor_binary_op(const Tensor& self,
@ -132,7 +167,8 @@ static void handle_tensor_tensor_binary_op(const Tensor& self,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
auto cplState = getCPLState(output, self, other, kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint32_t length = output.numel();
if (length == 0) {
return;
@ -162,7 +198,8 @@ static void handle_tensor_scalar_binary_op(const Tensor& self,
const std::string& kernel_name) {
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
auto cplState = getCPLState(output, self, other, kernel_name);
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(other), kernel_name);
uint64_t sval = other.to<int64_t>();
uint32_t length = output.numel();
if (length == 0) {
@ -199,7 +236,7 @@ static void _bitwise_op_out_mps(const Tensor& self,
auto output_size = at::infer_size_dimvector(self.sizes(), other.sizes());
resize_output(output, output_size);
if (needsGather(output)) {
if (!output.is_contiguous()) {
output = output.contiguous();
needs_output_copy = true;
}
@ -240,7 +277,7 @@ static void _bitwise_not_out_mps(const Tensor& self, const Tensor& output_) {
bool needs_output_copy = false;
resize_output(output, self.sizes());
if (needsGather(output)) {
if (!output.is_contiguous()) {
output = output.contiguous();
needs_output_copy = true;
}
@ -259,7 +296,8 @@ static void _bitwise_not_out_mps(const Tensor& self, const Tensor& output_) {
}
using namespace at::mps;
MPSStream* stream = getCurrentMPSStream();
auto cplState = getCPLState(output, self, self, "bitwise_not");
id<MTLComputePipelineState> cplState = getCPLState(
MPSDevice::getInstance()->device(), getMetalType(output), getMetalType(self), getMetalType(self), "bitwise_not");
dispatch_sync(stream->queue(), ^() {
getMPSProfiler().beginProfileKernel(cplState, "bitwise_not", {self});

View File

@ -17,7 +17,7 @@
namespace at::native {
namespace mps {
static MetalShaderLibrary lib(R"BUCKETIZE_METAL(
static const char* METAL_BUCKETIZATION = R"BUCKETIZE_METAL(
#include <metal_stdlib>
using namespace metal;
@ -194,7 +194,44 @@ REGISTER_SEARCHSORTED_OP(int, long);
REGISTER_SEARCHSORTED_OP(long, int);
REGISTER_SEARCHSORTED_OP(long, long);
)BUCKETIZE_METAL");
)BUCKETIZE_METAL";
static id<MTLLibrary> compileBucketizationOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> bucketizationLibrary = nil;
if (bucketizationLibrary) {
return bucketizationLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
bucketizationLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BUCKETIZATION
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(
bucketizationLibrary, "Failed to create metal bucketization library, error: ", [[error description] UTF8String]);
return bucketizationLibrary;
}
static id<MTLComputePipelineState> bucketizationPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> bucketizationLib = compileBucketizationOpsLibrary(device);
id<MTLFunction> bucketizationFunc =
[bucketizationLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(bucketizationFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:bucketizationFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
static void searchsorted_mps_contiguous(Tensor& result,
const Tensor& input,
@ -213,14 +250,15 @@ static void searchsorted_mps_contiguous(Tensor& result,
int64_t right_i64 = right;
int64_t is_1d_boundaries = boundaries.dim() == 1;
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input) + "_" +
scalarToMetalTypeString(result) + (sorter.defined() ? "_sorter" : "");
id<MTLComputePipelineState> bucketizationPSO = lib.getPipelineStateForFunc(kernel);
const std::string kernel = "searchsorted_" + scalarToMetalTypeString(input.scalar_type()) + "_" +
scalarToMetalTypeString(result.scalar_type()) + (sorter.defined() ? "_sorter" : "");
id<MTLComputePipelineState> bucketizationPSO = mps::bucketizationPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(bucketizationPSO, kernel, {input, boundaries, sorter});
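
As a reminder of what the searchsorted_* kernels dispatched above compute: right == false yields the leftmost valid insertion point, right == true the rightmost. A CPU sketch of the same semantics via the standard library (the boundary values below are arbitrary):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> boundaries{1, 3, 5, 7};
  const float v = 5;
  auto left  = std::lower_bound(boundaries.begin(), boundaries.end(), v) - boundaries.begin();
  auto right = std::upper_bound(boundaries.begin(), boundaries.end(), v) - boundaries.begin();
  std::printf("left=%td right=%td\n", left, right);  // left=2, right=3
}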

View File

@ -21,7 +21,7 @@ static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
}
Tensor output = self;
bool needsCopyToOutput = false;
if (needsGather(self)) {
if (!self.is_contiguous() || self.storage_offset()) {
output = at::empty(self.sizes(), self.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
needsCopyToOutput = true;
}

View File

@ -8,9 +8,7 @@
namespace at::native {
namespace {
using namespace mps;
static MetalShaderLibrary lib(R"CROSS_METAL(
static const char* METAL_CROSS = R"CROSS_METAL(
#include <metal_array>
#include <metal_stdlib>
@ -78,7 +76,44 @@ REGISTER_CROSS_OP(char);
REGISTER_CROSS_OP(uchar);
REGISTER_CROSS_OP(bool);
)CROSS_METAL");
)CROSS_METAL";
using namespace mps;
static id<MTLLibrary> compileCrossOpLibrary(id<MTLDevice> device) {
static id<MTLLibrary> crossLibrary = nil;
if (crossLibrary) {
return crossLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
crossLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_CROSS encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(crossLibrary, "Failed to create metal cross library, error: ", [[error description] UTF8String]);
return crossLibrary;
}
static id<MTLComputePipelineState> crossPipelineState(id<MTLDevice> device, ScalarType scalar_type) {
std::string kernel = "cross_" + scalarToMetalTypeString(scalar_type);
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> crossLib = compileCrossOpLibrary(device);
id<MTLFunction> crossFunc = [crossLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(crossFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:crossFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other, int64_t dim) {
TORCH_CHECK(input.dtype() != at::kDouble, "float64 is not supported on MPS");
@ -104,7 +139,7 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto kernelDataOffsets = generateKernelDataOffsets(computeEncoder, iter);
auto crossPSO = lib.getPipelineStateForFunc("cross_" + scalarToMetalTypeString(out));
id<MTLComputePipelineState> crossPSO = crossPipelineState(device, out.scalar_type());
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(crossPSO, "cross", {input, other});

View File

@ -24,7 +24,7 @@ namespace mps {
* See note [3-Clause BSD License for the Cephes Math Library].
*/
static MetalShaderLibrary lib(R"METAL(
static const char* GAMMA_OPS_TEMPLATE = R"METAL(
#include <metal_stdlib>
using namespace metal;
@ -388,11 +388,45 @@ kernel void polygamma(device {0} *input [[buffer(0)]],
output[id] = sgn * Gamma(n + 1) * calc_zeta(n + 1, x);
}}
)METAL",
2);
)METAL";
static id<MTLComputePipelineState> getCPLState(const Tensor& t1, const Tensor& t2, const std::string& fname) {
return lib.getPipelineStateForFunc(fname, {scalarToMetalTypeString(t1), scalarToMetalTypeString(t2)});
static id<MTLLibrary> compileGammaOpsLibrary(id<MTLDevice> device, const std::string& t1, const std::string& t2) {
auto key = t1 + t2;
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
return it->second;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(GAMMA_OPS_TEMPLATE, t1, t2).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
libMap[key] = rc;
return rc;
}
static id<MTLComputePipelineState> getCPLState(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& fname) {
auto key = t1 + t2 + fname;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
auto it = cplMap.find(key);
if (it != cplMap.end()) {
return it->second;
}
NSError* error = nil;
auto library = compileGammaOpsLibrary(device, t1, t2);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
TORCH_CHECK(
rc != nil && error == nil, "Failed to construct pipeline state: ", [[error localizedDescription] UTF8String]);
cplMap[key] = rc;
return rc;
}
} // namespace mps
@ -407,15 +441,19 @@ TORCH_IMPL_FUNC(lgamma_out_mps)(const Tensor& self, const Tensor& output_) {
return;
}
if (mps::needsGather(output_)) {
if (!self.is_contiguous()) {
output = output.contiguous();
needs_output_copy = true;
}
using namespace mps;
std::string input_type = scalarToMetalTypeString(self.scalar_type());
std::string output_type = scalarToMetalTypeString(output.scalar_type());
@autoreleasepool {
id<MTLComputePipelineState> cplState = getCPLState(self, output, "lgamma");
id<MTLDevice> device = MPSDevice::getInstance()->device();
id<MTLComputePipelineState> cplState = getCPLState(device, input_type, output_type, "lgamma");
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -447,15 +485,19 @@ TORCH_IMPL_FUNC(digamma_out_mps)(const Tensor& self, const Tensor& output_) {
return;
}
if (mps::needsGather(output_)) {
if (!self.is_contiguous()) {
output = output.contiguous();
needs_output_copy = true;
}
using namespace mps;
std::string input_type = scalarToMetalTypeString(self.scalar_type());
std::string output_type = scalarToMetalTypeString(output.scalar_type());
@autoreleasepool {
id<MTLComputePipelineState> cplState = getCPLState(self, output, "digamma");
id<MTLDevice> device = MPSDevice::getInstance()->device();
id<MTLComputePipelineState> cplState = getCPLState(device, input_type, output_type, "digamma");
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
@ -488,13 +530,15 @@ TORCH_IMPL_FUNC(polygamma_out_mps)(const int64_t order, const Tensor& self, cons
return;
}
if (mps::needsGather(output_)) {
if (!self.is_contiguous()) {
output = output.contiguous();
needs_output_copy = true;
}
using namespace mps;
std::string input_type = scalarToMetalTypeString(self.scalar_type());
std::string output_type = scalarToMetalTypeString(output.scalar_type());
std::string func_name;
if (order == 0) {
@ -506,7 +550,9 @@ TORCH_IMPL_FUNC(polygamma_out_mps)(const int64_t order, const Tensor& self, cons
}
@autoreleasepool {
id<MTLComputePipelineState> cplState = getCPLState(self, output, func_name);
id<MTLDevice> device = MPSDevice::getInstance()->device();
id<MTLComputePipelineState> cplState = getCPLState(device, input_type, output_type, func_name);
MPSStream* mpsStream = getCurrentMPSStream();
dispatch_sync(mpsStream->queue(), ^() {
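
The polygamma kernel above computes sgn * Gamma(n + 1) * calc_zeta(n + 1, x), i.e. the identity psi^(n)(x) = (-1)^(n+1) * n! * zeta(n + 1, x) for n >= 1, with zeta the Hurwitz zeta. A quick numerical sanity check at x = 1, where the Hurwitz zeta reduces to the Riemann zeta (assumes a standard library that ships the C++17 special math functions, e.g. libstdc++):

#include <cmath>
#include <cstdio>

int main() {
  for (int n = 1; n <= 3; ++n) {
    const double sgn = (n % 2 == 0) ? -1.0 : 1.0;  // (-1)^(n+1)
    const double psi_n = sgn * std::tgamma(n + 1.0) * std::riemann_zeta(n + 1.0);
    std::printf("polygamma(%d, 1) ~= %.6f\n", n, psi_n);
    // Expected: 1.644934 (pi^2/6), -2.404114 (-2*zeta(3)), 6.493939 (pi^4/15)
  }
}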

View File

@ -21,7 +21,7 @@ enum BIN_SELECTION_ALGORITHM {
BINARY_SEARCH,
};
static MetalShaderLibrary lib(R"HISTOGRAM_METAL(
static const char* METAL_HISTOGRAM = R"HISTOGRAM_METAL(
#include <metal_stdlib>
using namespace metal;
@ -157,7 +157,42 @@ kernel void kernel_index_offset(constant uint * strides [[buffer
data_offsets[thread_index] += remainder * strides[reversed_dim];
}
}
)HISTOGRAM_METAL");
)HISTOGRAM_METAL";
static id<MTLLibrary> compileHistogramOpLibrary(id<MTLDevice> device) {
static id<MTLLibrary> histogramLibrary = nil;
if (histogramLibrary) {
return histogramLibrary;
}
NSError* error = nil;
MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion:MTLLanguageVersion2_3];
histogramLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_HISTOGRAM
encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(histogramLibrary, "Failed to create metal histogram library, error: ", [[error description] UTF8String]);
return histogramLibrary;
}
static id<MTLComputePipelineState> histogramPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}
NSError* error = nil;
id<MTLLibrary> crossLib = compileHistogramOpLibrary(device);
id<MTLFunction> crossFunc = [crossLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(crossFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:crossFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
psoCache[kernel] = pso;
return pso;
}
template <typename input_t, BIN_SELECTION_ALGORITHM algorithm>
void histogramdd_kernel_impl(Tensor& hist_output,
@ -244,7 +279,7 @@ void histogramdd_kernel_impl(Tensor& hist_output,
id<MTLBuffer> stridedIndicesBuffer = [[device newBufferWithLength:stridedIndicesNumThreads * sizeof(uint)
options:0] autorelease];
id<MTLComputePipelineState> stridedIndicesPSO = lib.getPipelineStateForFunc("kernel_index_offset");
id<MTLComputePipelineState> stridedIndicesPSO = histogramPipelineState(device, "kernel_index_offset");
[computeEncoder setComputePipelineState:stridedIndicesPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim atIndex:0];
@ -254,8 +289,8 @@ void histogramdd_kernel_impl(Tensor& hist_output,
mtl_dispatch1DJob(computeEncoder, stridedIndicesPSO, stridedIndicesNumThreads);
const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input);
id<MTLComputePipelineState> histogramPSO = lib.getPipelineStateForFunc(kernel);
const std::string kernel = "histogramdd_" + scalarToMetalTypeString(input.scalar_type());
id<MTLComputePipelineState> histogramPSO = histogramPipelineState(device, kernel);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(histogramPSO, "histogram", allTensorsList);
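
kernel_index_offset above maps a linear thread index to a strided element offset by peeling dimensions from the innermost one outward (the "remainder * strides[reversed_dim]" line). A CPU sketch of the same index arithmetic; the sizes and strides below are made up for illustration:

#include <cstdint>
#include <cstdio>
#include <vector>

int64_t strided_offset(int64_t linear_idx,
                       const std::vector<int64_t>& sizes,
                       const std::vector<int64_t>& strides) {
  int64_t offset = 0;
  for (int64_t dim = static_cast<int64_t>(sizes.size()) - 1; dim >= 0; --dim) {
    const int64_t remainder = linear_idx % sizes[dim];
    linear_idx /= sizes[dim];
    offset += remainder * strides[dim];
  }
  return offset;
}

int main() {
  // Logical shape {2, 3} laid out transposed in memory: strides {1, 2}.
  const std::vector<int64_t> sizes{2, 3}, strides{1, 2};
  for (int64_t i = 0; i < 6; ++i) {
    std::printf("%lld -> %lld\n", (long long)i, (long long)strided_offset(i, sizes, strides));
  }
}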

View File

@ -241,20 +241,14 @@ static void index_put_kernel_mps(TensorIterator& iter,
} // namespace mps
static Tensor nonzero_fallback(const Tensor& self) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return at::nonzero(self.to("cpu")).clone().to("mps");
}
Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
Tensor out_fallback = nonzero_fallback(self);
at::native::resize_output(out_, out_fallback.sizes());
out_.copy_(out_fallback.to("mps"));
return out_;
} else if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
"Falling back on CPU. This may have performance implications.");
if (!is_macos_13_or_newer()) {
Tensor out_fallback = nonzero_fallback(self);
at::native::resize_output(out_, out_fallback.sizes());
out_.copy_(out_fallback.to("mps"));
@ -288,6 +282,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
MPSGraphTensor* scatterDataTensor_ = nil;
};
dispatch_sync(stream->queue(), ^() {
@ -299,26 +294,99 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
return out_;
}
bool contiguous_output = !needsGather(out_);
bool contiguous_output = out_.is_contiguous();
Tensor out = out_;
if (!contiguous_output) {
out = at::empty(out_.sizes(), out_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
}
int64_t _apparentInputShape = 1;
for (auto dim : self.sizes()) {
_apparentInputShape *= dim;
}
MPSShape* apparentOutputShape = @[ @(total_nonzero * nDim) ];
MPSShape* apparentInputShape = @[ @(_apparentInputShape) ];
// Pseudocode:
//
// inputTensor = [1, 0, 0, 3]
// inputNonZero = [1, 0, 0, 1]
// indices = [1, 1, 1, 2]
// maskedIndices = [0, -1, -1, 1]
// coordinates = [0, 1, 2, 3]
// scatterResult = [0, 3]
@autoreleasepool {
string key = "nonzero_out_mps" + getTensorsStringKey(self);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
MPSDataType inputDataType = getMPSDataType(self);
MPSShape* inputShape = getMPSShape(self);
MPSGraphTensor* outputTensor = [mpsGraph nonZeroIndicesOfTensor:inputTensor name:nil];
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
MPSGraphTensor* scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
MPSGraphTensor* minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
MPSGraphTensor* inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
secondaryTensor:zeroTensor
name:nil];
MPSGraphTensor* maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
toType:MPSDataTypeInt32
name:@"castToInt32"];
MPSGraphTensor* indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor axis:0 name:nil];
MPSGraphTensor* indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
secondaryTensor:oneTensor
name:nil];
MPSGraphTensor* maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
truePredicateTensor:indicesMinusOneTensor
falsePredicateTensor:minusMaxDimTensor
name:nil];
MPSGraphTensor* coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0
withShape:inputShape
name:nil]
withShape:@[ @-1 ]
name:nil];
if (nDim > 1) {
NSMutableArray<MPSGraphTensor*>* maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
NSMutableArray<MPSGraphTensor*>* coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
MPSGraphTensor* constantRankTensor = [mpsGraph constantWithScalar:nDim dataType:MPSDataTypeInt32];
maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
secondaryTensor:constantRankTensor
name:nil];
coordinatesTensorArray[0] = coordinatesTensor;
for (int i = 1; i < nDim; i++) {
maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
secondaryTensor:oneTensor
name:nil];
coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i
withShape:inputShape
name:nil]
withShape:@[ @-1 ]
name:nil];
}
maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
}
MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
updatesTensor:coordinatesTensor
indicesTensor:maskedIndicesTensor
axis:0
mode:MPSGraphScatterModeSet
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->scatterDataTensor_ = scatterDataTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);
auto feeds = dictionaryFromPlaceholders(selfPlaceholder, scatterPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
@ -330,13 +398,7 @@ Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
}
Tensor nonzero_mps(const Tensor& self) {
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");
return nonzero_fallback(self);
} else if (self.is_complex()) {
TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
"Falling back on CPU. This may have performance implications.");
if (!is_macos_13_or_newer()) {
return nonzero_fallback(self);
}
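
The Pseudocode comment in the hunk above (mask the input, cumulative-sum the mask, subtract one, then scatter coordinates at the masked indices) is the whole algorithm. A CPU sketch of the 1-D case that reproduces the [1, 0, 0, 3] -> [0, 3] example; illustration only, the MPSGraph version applies this per flattened coordinate and interleaves the dimensions:

#include <cstdio>
#include <vector>

std::vector<long> nonzero_1d(const std::vector<float>& input) {
  std::vector<long> masked(input.size(), -1);  // maskedIndices: -1 marks "drop"
  long count = 0;                              // running cumulative sum of the mask
  for (size_t i = 0; i < input.size(); ++i) {
    if (input[i] != 0.0f) masked[i] = count++;  // indices - 1, gated by the mask
  }
  std::vector<long> out(count);
  for (size_t i = 0; i < input.size(); ++i) {
    if (masked[i] >= 0) out[masked[i]] = static_cast<long>(i);  // scatter the coordinate
  }
  return out;
}

int main() {
  for (long idx : nonzero_1d({1, 0, 0, 3})) std::printf("%ld ", idx);  // prints: 0 3
  std::printf("\n");
}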

Some files were not shown because too many files have changed in this diff.