Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-27 09:04:53 +08:00)

Compare commits: export-D79 ... codex-test (199 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| bc67bce2e5 | |||
| 79eca4677b | |||
| 2855688a1d | |||
| 2231c3ca3a | |||
| c03a734ba1 | |||
| 98316e5896 | |||
| 23cf241039 | |||
| e7feedf6a9 | |||
| dad2a05bec | |||
| 0495cab545 | |||
| abfe403981 | |||
| 1690c0c3a0 | |||
| e9d27aa8fd | |||
| 2457e62c90 | |||
| d0fccbc99c | |||
| 3461988a4b | |||
| 9764981116 | |||
| 704594eb23 | |||
| bfc27cf468 | |||
| 311f74089a | |||
| 14c7358c64 | |||
| 8ce81bcee1 | |||
| 4604f0482c | |||
| 15f1173e5d | |||
| e16c48ae97 | |||
| f7a66da5f9 | |||
| 3eb3da9b4b | |||
| 3ddfd46bd2 | |||
| 6a82da392e | |||
| 22bedc429f | |||
| 49abc0e3f8 | |||
| 1052604acd | |||
| fe8984a9f4 | |||
| 74a754aae9 | |||
| b1ec088113 | |||
| fb35a9ea4a | |||
| 8034b2a732 | |||
| 64cc6f06b1 | |||
| 410812763b | |||
| bdb07a2bc5 | |||
| 8085edc8f9 | |||
| 882d50c5bf | |||
| b52a4d0821 | |||
| a45a840926 | |||
| 9b953bb3fb | |||
| eb25a95a6e | |||
| 9884d0351e | |||
| d7c83972d5 | |||
| e06b110f73 | |||
| 0ba09a6d34 | |||
| aeb5321b63 | |||
| 625108ede2 | |||
| 09e5a93fcb | |||
| 908c5cc4c0 | |||
| c1145852a5 | |||
| ae1a706444 | |||
| 56d19a5ced | |||
| b6c53383fe | |||
| 4fd5fabee9 | |||
| bbc0df1094 | |||
| 33ec6e3e9a | |||
| efc4b460b3 | |||
| 1ca8388442 | |||
| b69497351d | |||
| 482f069c41 | |||
| 85d931f29e | |||
| 8a2f53c523 | |||
| b59b61a099 | |||
| 57ab39f7e4 | |||
| 182975e01a | |||
| 9f8cfe7476 | |||
| e273ff028a | |||
| 5e0fc2c9a9 | |||
| bc4b04e058 | |||
| 6b414f56a4 | |||
| fb8f32ef52 | |||
| 7ba996bbaa | |||
| ddbdcdc710 | |||
| 19f1f9960d | |||
| fd6655a0f5 | |||
| a7f3bdf550 | |||
| 510e8b4ae0 | |||
| 83ba3f1101 | |||
| 1fad16aacb | |||
| 444e2381d0 | |||
| 6085bf7565 | |||
| 8201dbf4bc | |||
| 26d045bb60 | |||
| 356ac3103a | |||
| d4109a0f99 | |||
| 7ea789ccfb | |||
| 7e8197e34d | |||
| 50eac811a6 | |||
| 4e0f179d0b | |||
| 36e59d9b12 | |||
| fc340d0ca3 | |||
| 53e47af0f7 | |||
| 66ad881fc7 | |||
| 1d3eef27ac | |||
| dd95900cec | |||
| 1cdd665526 | |||
| 7cb2dcd2dd | |||
| e5a81aa7ba | |||
| 3e2aa4b0e3 | |||
| 6646461764 | |||
| f74da2a136 | |||
| d35b27dde5 | |||
| a9dc1566d4 | |||
| 33a1996714 | |||
| ee62177c19 | |||
| 64cbaa876c | |||
| 4516c59f5f | |||
| 8bc843a9ec | |||
| e39a62c70d | |||
| 978e3a9142 | |||
| e2a5c42e7e | |||
| 5116c49b52 | |||
| fecdebe385 | |||
| e136a9175b | |||
| 9a680e14b7 | |||
| 805a102beb | |||
| 6e8d705a22 | |||
| 9c18901bfd | |||
| a29ed5e1ac | |||
| d2792f51b2 | |||
| be71000ff5 | |||
| 3f86076775 | |||
| 1616777cd2 | |||
| 38895c0ac2 | |||
| 310f901a71 | |||
| e11b1cd97e | |||
| b599d91738 | |||
| fd6a6658c3 | |||
| 04973496a8 | |||
| 1548b011ea | |||
| e57a92734d | |||
| 79ff3b320b | |||
| 426f249f20 | |||
| d33a484763 | |||
| a81ffbc5f5 | |||
| 465fe4d9f7 | |||
| 9477af1063 | |||
| dcc36e38bb | |||
| efd78584a8 | |||
| 135762ea20 | |||
| e2ee9cfaa2 | |||
| 06d28de17a | |||
| df9720b8b5 | |||
| 85e74d5ace | |||
| 0450f05658 | |||
| 595a65f5c2 | |||
| 8c6c2e40eb | |||
| 32840d19f9 | |||
| 2040f00112 | |||
| c137f9da0b | |||
| 5e8b95605f | |||
| 8ea86a6e31 | |||
| acad808545 | |||
| c687446374 | |||
| dd22ba09b4 | |||
| c0e0126399 | |||
| e4b123b5e4 | |||
| 5711a8f069 | |||
| b4b71d011e | |||
| 52376b9b6f | |||
| 1371a98b0e | |||
| 2a286cbdf4 | |||
| 7c37b8e1e0 | |||
| ee2649219c | |||
| b0b3e6e48b | |||
| 3967dbedf4 | |||
| 4396b15aa7 | |||
| bb6766053b | |||
| a4fc051c9a | |||
| 5cc6a0abc1 | |||
| 90f13f3b2a | |||
| cb9b74872b | |||
| c964204829 | |||
| 2ac45c2752 | |||
| 83e2ea8135 | |||
| d994027a41 | |||
| cb4f41e125 | |||
| 690fc9cf88 | |||
| eb853e222b | |||
| 06395276e4 | |||
| 8becf646ef | |||
| fa68216ca1 | |||
| 25ef3d315d | |||
| 7e00f2ec9d | |||
| 490cb3f1a4 | |||
| b95cf5c91d | |||
| 5e2ef2a465 | |||
| 9f753f8c0d | |||
| db437690d1 | |||
| 669009bcd1 | |||
| e4e2701429 | |||
| 64cc649275 | |||
| b1fb552974 | |||
| bb62e1f769 |
@@ -144,16 +144,6 @@ case "$tag" in
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9)
CUDA_VERSION=12.6.3
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
@@ -164,39 +154,6 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.13
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.10
@@ -219,18 +176,6 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.11-clang12)
ANACONDA_PYTHON_VERSION=3.11
CLANG_VERSION=12
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.9-gcc9)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=9
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10

@@ -1 +1 @@
11ec6354315768a85da41032535e3b7b99c5f706
f7888497a1eb9e98d4c07537f0d0bcfe180d1363
@@ -68,8 +68,8 @@ function install_nvshmem {
# download, unpack, install
wget -q "${url}"
tar xf "${filename}.tar.gz"
cp -a "libnvshmem/include/"* /usr/local/include/
cp -a "libnvshmem/lib/"* /usr/local/lib/
cp -a "libnvshmem/include/"* /usr/local/cuda/include/
cp -a "libnvshmem/lib/"* /usr/local/cuda/lib64/

# cleanup
cd ..
@@ -15,11 +15,37 @@ function install_timm() {
commit=$(get_pinned_commit timm)

pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}"
# Clean up
conda_run pip uninstall -y torch torchvision triton
}

function install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"

python install.py --continue_on_fail

# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1

echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd

chown -R jenkins torchbench
}

# Pango is needed for weasyprint which is needed for doctr
conda_install pango

# Stable packages are ok here, just to satisfy TorchBench check
pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128

install_torchbench
install_huggingface
install_timm

# Clean up
conda_run pip uninstall -y torch torchvision torchaudio triton
@@ -103,5 +103,5 @@ fi
# It depends on torch and triton. We don't want to install
# triton and torch from production on Docker CI images
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
pip_install helion==0.0.10 --no-deps
pip_install helion --no-deps
fi
@@ -361,7 +361,6 @@ pwlf==2.2.1
#Pinned versions: 2.2.1
#test that import: test_sac_estimator.py


# To build PyTorch itself
pyyaml
pyzstd

@@ -1,7 +1,7 @@
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2

# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
@@ -50,7 +50,7 @@ IPython==8.12.0
#Pinned versions: 8.12.0

myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Pinned versions: 0.17.2

# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs
@@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt

# (optional) Install non-default Ninja version
ARG NINJA_VERSION

@@ -98,8 +98,9 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt

ARG TRITON
ARG TRITON_CPU
@@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library
ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES)

# hipblaslt library files
HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library
@@ -229,7 +229,6 @@ function install_torchrec_and_fbgemm() {

pip_install tabulate # needed for newer fbgemm
pip_install patchelf # needed for rocm fbgemm
pushd /tmp

local wheel_dir=dist/fbgemm_gpu
local found_whl=0
@@ -245,7 +244,7 @@ function install_torchrec_and_fbgemm() {
if [ "${found_whl}" == "0" ]; then
git clone --recursive https://github.com/pytorch/fbgemm
pushd fbgemm/fbgemm_gpu
git checkout "${fbgemm_commit}"
git checkout "${fbgemm_commit}" --recurse-submodules
python setup.py bdist_wheel \
--build-variant=rocm \
-DHIP_ROOT_DIR="${ROCM_PATH}" \
@@ -264,7 +263,6 @@ function install_torchrec_and_fbgemm() {
done

rm -rf fbgemm
popd
else
pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec
pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu
@@ -283,30 +281,6 @@ function clone_pytorch_xla() {
fi
}

function checkout_install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"

if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi

# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1

echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}

function install_torchao() {
local commit
commit=$(get_pinned_commit torchao)
@@ -157,6 +157,29 @@ test_jit_hooks() {
assert_git_not_dirty
}

# Shellcheck doesn't like it when you pass no arguments to a function
# that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2120
checkout_install_torchbench() {
local commit
commit=$(cat .ci/docker/ci_commit_pins/torchbench.txt)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"

if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi

echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}

torchbench_setup_macos() {
git clone --recursive https://github.com/pytorch/vision torchvision
git clone --recursive https://github.com/pytorch/audio torchaudio
@@ -179,8 +202,6 @@ torchbench_setup_macos() {
USE_OPENMP=0 python setup.py develop
popd

# Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2119,SC2120
checkout_install_torchbench
}
@@ -627,6 +627,8 @@ test_perf_for_dashboard() {
device=cuda_a10g
elif [[ "${TEST_CONFIG}" == *h100* ]]; then
device=cuda_h100
elif [[ "${TEST_CONFIG}" == *b200* ]]; then
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
fi
@@ -801,6 +803,16 @@ test_dynamo_benchmark() {
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# TODO (huydhn): Just smoke test some sample models
if [[ "${TEST_CONFIG}" == *b200* ]]; then
if [[ "${suite}" == "huggingface" ]]; then
export TORCHBENCH_ONLY_MODELS="DistillGPT2"
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="hf_Bert"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
else
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
@@ -1672,13 +1684,11 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then
elif [[ "${TEST_CONFIG}" == cachebench ]]; then
install_torchaudio
install_torchvision
checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco
PYTHONPATH=$(pwd)/torchbench test_cachebench
PYTHONPATH=/torchbench test_cachebench
elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then
install_torchaudio
install_torchvision
checkout_install_torchbench nanogpt
PYTHONPATH=$(pwd)/torchbench test_verify_cachebench
PYTHONPATH=/torchbench test_verify_cachebench
elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
install_torchaudio
install_torchvision
@@ -1687,28 +1697,22 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
# https://github.com/opencv/opencv-python/issues/885
pip_install opencv-python==4.8.0.74
if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then
checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf
PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf
elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then
checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \
llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \
functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf
PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf
elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then
checkout_install_torchbench
TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest
TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest
else
checkout_install_torchbench
# Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* ]]; then
install_torchrec_and_fbgemm
fi
PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id"
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
fi
elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
install_torchvision
PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
test_inductor_aoti
fi
.github/ci_commit_pins/audio.txt (vendored, 2 changed lines)
@@ -1 +1 @@
bf305f538005f2e900f8850ed57146024a8bc559
6fbc710b617f79b992ef2ebc7f95e818aa390293

.github/ci_commit_pins/vllm.txt (vendored, 2 changed lines)
@@ -1 +1 @@
ca9e2be3ed6320b51f52f536595cd24e254f8bb2
6a39ba85fe0f2fff9494b5eccea717c93510c230

.github/ci_commit_pins/xla.txt (vendored, 2 changed lines)
@@ -1 +1 @@
29ae4c76c026185f417a25e841d2cd5e65f087a3
b6a5b82b9948b610fa4c304d0d869c82b8f17db1
.github/merge_rules.yaml (vendored, 4 changed lines)
@@ -488,6 +488,10 @@
- torch/_dynamo/**
- torch/csrc/dynamo/**
- test/dynamo/**
- test/dynamo_expected_failures/**
- test/dynamo_skips/**
- test/inductor_expected_failures/**
- test/inductor_skips/**
approved_by:
- guilhermeleobas
mandatory_checks_name:
@@ -193,7 +193,7 @@ LIBTORCH_CONTAINER_IMAGES: dict[str, str] = {
"cpu": "libtorch-cxx11-builder:cpu",
}

FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"]
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]


def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
@@ -315,6 +315,11 @@ def generate_wheels_matrix(
# TODO: Enable python 3.13t on cpu-s390x
if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
continue
# TODO: Enable python 3.14 on non linux OSes
if os != "linux" and (
python_version == "3.14" or python_version == "3.14t"
):
continue

if use_split_build and (
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"
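Note: the added 3.14/3.14t entries follow the same pattern as the existing 3.13t-on-s390x exclusion; the generator simply skips unsupported (os, python) pairs. A minimal sketch of that filtering, using an illustrative helper `is_supported` rather than the real generator code:

```python
# Sketch only: `is_supported` is a hypothetical helper, not part of the real
# wheel-matrix script; the two skip rules mirror the hunk above.
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]

def is_supported(os_name: str, gpu_arch_type: str, python_version: str) -> bool:
    if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
        return False  # 3.13t not yet enabled on cpu-s390x
    if os_name != "linux" and python_version in ("3.14", "3.14t"):
        return False  # 3.14 wheels are linux-only for now
    return True

print([v for v in FULL_PYTHON_VERSIONS if is_supported("macos", "cpu", v)])
# ['3.9', '3.10', '3.11', '3.12', '3.13', '3.13t']
```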
.github/workflows/_linux-test.yml (vendored, 20 changed lines)
@@ -96,7 +96,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |
@@ -109,7 +109,7 @@ jobs:
no-sudo: true

- name: Setup Python
if: matrix.runner == 'B200'
if: contains(matrix.runner, 'b200')
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'
@@ -117,7 +117,7 @@ jobs:

- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200'
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200')

- name: configure aws credentials
if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
@@ -128,7 +128,7 @@ jobs:
aws-region: us-east-1

- name: Login to Amazon ECR
if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }}
if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }}
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
@@ -166,17 +166,17 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
with:
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}

- name: Setup GPU_FLAG for docker run
id: setup-gpu-flag
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }}

- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
id: setup-sscache-port-flag
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }}
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }}

- name: Lock NVIDIA A100 40GB Frequency
run: |
@@ -277,8 +277,8 @@ jobs:
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }}
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }}
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
@@ -403,7 +403,7 @@ jobs:
job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }}

- name: Authenticate with AWS
if: ${{ matrix.runner == 'B200' }}
if: ${{ contains(matrix.runner, 'b200') }}
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
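Note on the `_linux-test.yml` change: replacing `matrix.runner == 'B200'` with `contains(matrix.runner, 'b200')` switches from an exact label match to a case-insensitive substring match, so runner labels such as `linux.dgx.b200` are also recognized as B200 machines. A rough Python model of the two conditions (not GitHub's implementation):

```python
# Rough model of the expression change in the workflow above.
def old_condition(runner: str) -> bool:
    return runner == "B200"          # exact label only

def new_condition(runner: str) -> bool:
    return "b200" in runner.lower()  # contains() is case-insensitive

for label in ("B200", "linux.dgx.b200", "linux.4xlarge.nvidia.gpu"):
    print(f"{label:28} old={old_condition(label)!s:5} new={new_condition(label)}")
```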
.github/workflows/check-labels.yml (vendored, 3 changed lines)
@@ -34,7 +34,8 @@ jobs:
contents: read
pull-requests: write
name: Check labels
if: github.repository_owner == 'pytorch'
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
runs-on: linux.24_04.4x
steps:
- name: Checkout PyTorch

@@ -7,7 +7,8 @@ on:

jobs:
ghstack-mergeability-check:
if: github.repository_owner == 'pytorch'
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
.github/workflows/docker-builds.yml (vendored, 8 changed lines)
@@ -51,17 +51,12 @@ jobs:
docker-image-name: [
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm,
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.9-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-rocm-n-py3,
pytorch-linux-noble-rocm-n-py3,
@@ -76,7 +71,8 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
pytorch-linux-jammy-py3-clang12-executorch,
# Executorch pin needs update
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu
]
include:
.github/workflows/generated-linux-binary-manywheel-nightly.yml (generated, vendored, 1226 changed lines)
File diff suppressed because it is too large.

.github/workflows/inductor-perf-test-b200.yml (vendored, new file, 154 lines)
@@ -0,0 +1,154 @@
name: inductor-perf-b200

on:
schedule:
- cron: 0 7 * * 1-6
- cron: 0 7 * * 0
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: true
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf

build:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
# Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100
# or newer GPUs, so it doesn't benefit much from existing compiler cache
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
selected-test-configs: ${{ inputs.benchmark_configs }}
build-additional-packages: "vision audio fbgemm torchao"
secrets: inherit

test-periodically:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

test-weekly:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 1440
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

test:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
.github/workflows/inductor-periodic.yml (vendored, 30 changed lines)
@@ -81,21 +81,21 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
]}
secrets: inherit
.github/workflows/nightly.yml (vendored, 9 changed lines)
@@ -75,10 +75,11 @@ jobs:
repo-owner: pytorch
branch: main
pin-folder: .github/ci_commit_pins
- repo-name: executorch
repo-owner: pytorch
branch: main
pin-folder: .ci/docker/ci_commit_pins
# executorch jobs are disabled since it needs some manual work for the hash update
# - repo-name: executorch
# repo-owner: pytorch
# branch: main
# pin-folder: .ci/docker/ci_commit_pins
- repo-name: triton
repo-owner: triton-lang
branch: main
.github/workflows/periodic.yml (vendored, 31 changed lines)
@@ -51,37 +51,6 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}

linux-jammy-cuda12_4-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit

linux-jammy-cuda12_4-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_4-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit

linux-jammy-cuda12_4-py3_10-gcc11-build:
name: linux-jammy-cuda12.4-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml
.github/workflows/pull.yml (vendored, 43 changed lines)
@@ -292,13 +292,14 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit

@@ -402,38 +403,8 @@ jobs:
]}
secrets: inherit

linux-jammy-cuda12_8-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit

linux-jammy-cuda12_8-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit

linux-jammy-py3-clang12-executorch-build:
if: false # Docker build needs pin update
name: linux-jammy-py3-clang12-executorch
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
.github/workflows/torchbench.yml (vendored, 4 changed lines)
@@ -10,6 +10,10 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true

permissions:
id-token: write
contents: read

jobs:
get-default-label-prefix:
if: github.repository_owner == 'pytorch'
.github/workflows/trunk.yml (vendored, 2 changed lines)
@@ -205,7 +205,7 @@ jobs:
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
test-matrix: |
{ include: [
{ config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
.github/workflows/update-viablestrict.yml (vendored, 2 changed lines)
@@ -23,7 +23,7 @@ jobs:
with:
repository: pytorch/pytorch
stable-branch: viable/strict
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]'
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]'
secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }}
clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}
AGENTS.md (17 changed lines)
@@ -1 +1,18 @@
- This is the only AGENTS.md, there are no recursive AGENTS.md
- When you are working on a bug, first create a standalone file that
  reproduces the bug and verify it fails in the expected way. Use this to
  test if your changes work. Once the change is passing, find an appropriate
  test file to add the test to and make sure to follow local conventions on
  the test file.
- If you are running the real test suite, DO NOT run the entire test suite.
  Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir'
- Do NOT run setup.py, you do not have a working build environment
- Do NOT run pre-commit, it is not setup
- To run lint, run 'lintrunner -a' (which will autoapply changes). lintrunner
  ONLY accepts this flag, do not try to run on individual files.
- Do NOT attempt to install dependencies, you do not have Internet access
- When you are ready to make a PR, do exactly these steps:
  - git stash -u
  - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch
  - git stash pop
  - Resolve conflicts if necessary
@@ -14,7 +14,6 @@
/torch/csrc/autograd/ @albanD @soulitzer
/torch/autograd/ @albanD @soulitzer
/tools/autograd/ @albanD @soulitzer
/torch/header_only_apis.txt @janeyx99
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
/torch/optim/ @albanD @janeyx99
/test/test_public_bindings.py @albanD
@@ -196,3 +195,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan

# Relating to libtorch ABI
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99
@@ -276,7 +276,7 @@ conda install pkg-config libuv
pip install mkl-static mkl-include
# Add these packages if torch.distributed is needed.
# Distributed package support on Windows is a prototype feature and is subject to changes.
conda install -c conda-forge libuv=1.39
conda install -c conda-forge libuv
```

#### Install PyTorch
@@ -439,6 +439,7 @@ if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
@@ -703,21 +704,17 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
list(APPEND AIR_BASIC ${TGT_BASIC})
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
endforeach()
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
add_custom_command(
COMMAND echo "// $$(date)" > metallib_dummy.cpp
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
DEPENDS kernels_basic.metallib
OUTPUT metallib_dummy.cpp
COMMENT "Updating metallibs timestamp")
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
else()
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
foreach(SHADER ${native_mps_metal})
@@ -1,5 +1,6 @@
#pragma once

#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>

@@ -72,6 +73,27 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
// original device index that was active before the change.
TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);

TORCH_API inline void emptyCache() {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->emptyCache();
}

TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getDeviceStats(device_index);
}

TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index);
}

TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}

} // namespace at::accelerator

namespace at {
@@ -2,7 +2,6 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>

#include <cstddef>

@@ -2,6 +2,7 @@

#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>
@@ -24,6 +24,29 @@ static void _assert_match(const O& original, const C& compared, const std::strin
}
}

template<>
void _assert_match<c10::Device, std::optional<c10::Device>>(
const c10::Device& original,
const std::optional<c10::Device>& compared,
const std::string& name) {
if (compared) {
const c10::Device& expected = compared.value();
if (original.type() != expected.type()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}

// If the expected device doesn't have an index (e.g., just "cuda"),
// or if both devices have the same index, consider them equal
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}
}
}

void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
_assert_match(tensor.sym_sizes(), sizes, "sizes");
_assert_match(tensor.sym_strides(), strides, "strides");
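Note on the `_assert_match` specialization added above: device type must always match, but a device index is only compared when both the expected and the actual device carry one, so an index-less expectation such as plain "cuda" accepts any CUDA device. A small Python restatement of that rule (illustrative only, not the ATen code):

```python
# Illustrative model of the device-matching rule in the hunk above.
from typing import NamedTuple, Optional

class Device(NamedTuple):
    type: str
    index: Optional[int] = None

def devices_match(got: Device, expected: Optional[Device]) -> bool:
    if expected is None:
        return True                      # nothing was specified to check against
    if got.type != expected.type:
        return False                     # e.g. cpu vs cuda always mismatches
    if expected.index is None or got.index is None:
        return True                      # "cuda" matches "cuda:0", "cuda:1", ...
    return got.index == expected.index

assert devices_match(Device("cuda", 1), Device("cuda"))
assert not devices_match(Device("cuda", 1), Device("cuda", 0))
assert not devices_match(Device("cpu"), Device("cuda"))
```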
@@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
auto* C_data = C.data_ptr<T>();
const auto* S_data = scales.const_data_ptr<T>();

int M = A.size(0);
int N = B.size(0);
int K = A.size(1);
int lda = A.stride(0);
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 4;
int64_t M = A.size(0);
int64_t N = B.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 4;

const int MB = (M + BLOCK_M - 1) / BLOCK_M;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;

at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
int mb{0}, nb{0};
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);

for (const auto i : c10::irange(begin, end)) {
(void)i;

int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start);
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);

const auto* A_ptr = A_data + mb_start * lda;
const auto* B_ptr = B_data + nb_start * K;
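Note on the int8pack_mm change above: the sizes, strides, and block indices move from `int` to `int64_t` so that products such as `MB * NB` and `mb_start * lda` cannot overflow 32 bits on very large inputs; the blocking itself is an ordinary ceil-division over a 2-D grid. A small Python sketch of that decomposition (assuming the 4x4 blocks used by the kernel):

```python
# Sketch of the 4x4 blocking walked by the kernel above; Python ints do not
# overflow, but in C++ the equivalent products need 64-bit types.
BLOCK_M = BLOCK_N = 4

def blocks(M: int, N: int):
    MB = (M + BLOCK_M - 1) // BLOCK_M   # number of row blocks (ceil division)
    NB = (N + BLOCK_N - 1) // BLOCK_N   # number of column blocks
    for idx in range(MB * NB):          # flattened block index, nb fastest
        mb, nb = divmod(idx, NB)
        mb_start, nb_start = mb * BLOCK_M, nb * BLOCK_N
        yield mb_start, min(BLOCK_M, M - mb_start), nb_start, min(BLOCK_N, N - nb_start)

print(list(blocks(6, 5))[:3])  # [(0, 4, 0, 4), (0, 4, 4, 1), (4, 2, 0, 4)]
```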
@@ -526,7 +526,7 @@ namespace {


// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enought to go beyond int32,
// TODO: to really support input tensor large enough to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
@@ -681,7 +681,7 @@ namespace {
const dim3 grid(grid_x, grid_y, grid_z);

// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enought to go beyond int32,
// TODO: to really support input tensor large enough to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
@@ -1634,6 +1634,9 @@ bool use_fast_accum) {
  TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
  const bool a_is_2d = mat_a.dim() == 2;
  const bool b_is_2d = mat_b.dim() == 2;
  if (!a_is_2d || !b_is_2d) {
    TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
  }
  TORCH_CHECK(
      mat_a.size(-1) % 16 == 0,
      "Expected trailing dimension of mat_a to be divisible by 16 ",
@@ -1716,6 +1719,9 @@ std::optional<c10::ScalarType> out_dtype) {
  TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
  const bool a_is_2d = mat_a.dim() == 2;
  const bool b_is_2d = mat_b.dim() == 2;
  if (!a_is_2d || !b_is_2d) {
    TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
  }

  // check that the strides are valid, the fn will throw an error if not
  check_valid_strides_and_return_transposed(mat_a);

@@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
class CuFFTConfig {
public:

  // Only move semantics is enought for this class. Although we already use
  // Only move semantics is enough for this class. Although we already use
  // unique_ptr for the plan, still remove copy constructor and assignment op so
  // we don't accidentally copy and take perf hit.
  CuFFTConfig(const CuFFTConfig&) = delete;

@ -241,6 +241,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>(
|
||||
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
|
||||
@ -264,6 +266,8 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
0,
|
||||
0,
|
||||
a_row_major,
|
||||
|
||||
@@ -38,18 +38,20 @@ __global__ void prepare_grouped_gemm_data(
    Strides tensor_StrideA,
    Strides tensor_StrideB,
    Strides tensor_StrideOutput,
    Strides tensor_ShapeA,
    Strides tensor_ShapeB,
    int64_t a_scale_stride,
    int64_t b_scale_stride,
    bool a_row_major = true,
    bool b_row_major = false) {
  int32_t tid = threadIdx.x;
  int32_t delta = 0;
  int32_t offset = 0;
  if (offs != nullptr) {
    int32_t start = tid == 0 ? 0 : offs[tid - 1];
    delta = offs[tid] - start;
    if (K < 0) {
      CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
    }
    offset = offs[tid];
    delta = offset - start;
    CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n");

    // TMA transfers require global memory tensor addresses to be
    // aligned to 16 bytes.
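The comment above refers to the 16-byte alignment TMA requires for global-memory tensor addresses. A hedged sketch of such a check in plain C++ (the helper name is illustrative, not from the patch):

#include <cstdint>

// Illustrative only: true when a pointer sits on a 16-byte boundary,
// the granularity TMA expects for global-memory tensors.
inline bool is_16_byte_aligned(const void* p) {
  return (reinterpret_cast<std::uintptr_t>(p) & 0xF) == 0;
}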
@ -84,6 +86,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
int64_t lda, ldb, ldoutput;
|
||||
if (M < 0) {
|
||||
// A and output is 2d
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n");
|
||||
M = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
|
||||
@ -96,6 +99,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
|
||||
B_ptrs[tid] = B + tid * tensor_StrideB[0];
|
||||
} else if (N < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n");
|
||||
N = delta;
|
||||
lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
|
||||
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
|
||||
@ -108,6 +112,7 @@ __global__ void prepare_grouped_gemm_data(
|
||||
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
|
||||
}
|
||||
} else if (K < 0) {
|
||||
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n");
|
||||
// A, B is 2d, output is 3d
|
||||
K = delta;
|
||||
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
|
||||
|
||||
@@ -644,7 +644,12 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
  Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))

  // As above, there may be better configurations to use.
  constexpr int max_threads = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
  constexpr int max_threads_ = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
  int max_threads = max_threads_;
  // Blackwell launch bounds
  if (at::cuda::getCurrentDeviceProperties()->major >= 10) {
    max_threads = 512;
  }
  int threads_target = max_threads;
  while (threads_target / 2 >= 2*max_target_length+1) {
    threads_target /= 2;

@ -298,6 +298,9 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
Strides tensor_StrideA = make_strides(mat_a.strides());
|
||||
Strides tensor_StrideB = make_strides(mat_b.strides());
|
||||
Strides tensor_StrideOutput = make_strides(out.strides());
|
||||
Strides tensor_ShapeA = make_strides(mat_a.sizes());
|
||||
Strides tensor_ShapeB = make_strides(mat_b.sizes());
|
||||
|
||||
// scale stride will be used inside the kernel only if needed,
|
||||
// so for 1d scales the "1" assigned here won't be used
|
||||
int64_t a_scale_stride = scale_a.stride(0);
|
||||
@ -325,6 +328,8 @@ void f8f8bf16_grouped_gemm_impl_sm90(
|
||||
tensor_StrideA,
|
||||
tensor_StrideB,
|
||||
tensor_StrideOutput,
|
||||
tensor_ShapeA,
|
||||
tensor_ShapeB,
|
||||
a_scale_stride,
|
||||
b_scale_stride);
|
||||
|
||||
|
||||
74
aten/src/ATen/native/cuda/int8mm.cu
Normal file
@@ -0,0 +1,74 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

namespace at::native {

__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) {
  // one thread per output element: [B, N]
  int b = blockIdx.y * blockDim.y + threadIdx.y;
  int n = blockIdx.x * blockDim.x + threadIdx.x;

  if (b >= B || n >= N) return;

  float acc = 0.0f;
  for (int k = 0; k < K; ++k) {
    acc += x[b * K + k] * static_cast<float>(w[n * K + k]);
  }

  out[b * N + n] = acc * scale[n];
}

void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) {
  const int B = x.size(0);
  const int K = x.size(1);
  const int N = w_int8.size(0);

  const dim3 block(16, 16);
  const dim3 grid((N + block.x - 1) / block.x, (B + block.y - 1) / block.y);

  auto stream = at::cuda::getCurrentCUDAStream();

  weight_int8pack_mm_kernel<<<grid, block, 0, stream>>>(
      x.data_ptr<float>(),
      w_int8.data_ptr<int8_t>(),
      scale.data_ptr<float>(),
      out.data_ptr<float>(),
      B, K, N);
}


// Main GPU entry point
at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) {
  // --- Check inputs ---
  TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
  TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor");
  TORCH_CHECK(scale.is_cuda(), "scale must be a CUDA tensor");

  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(w_int8.dim() == 2, "w must be 2D");
  TORCH_CHECK(scale.dim() == 1, "scale must be 1D");

  TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)");
  TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)");

  // --- Determine shapes ---
  auto B = x.size(0); // batch size
  auto N = w_int8.size(0); // output dim

  // Ensure inputs are in the correct types for the kernel
  auto x_f32 = x.to(at::kFloat);
  auto w_int8_contiguous = w_int8.contiguous();
  auto scale_f32 = scale.to(at::kFloat);

  // --- Allocate output ---
  auto out = at::empty({B, N}, x.options().dtype(at::kFloat));

  // --- Launch kernel ---
  launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out);

  return out;
}

} // namespace at::native
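The new kernel computes out[b][n] = scale[n] * sum_k x[b][k] * w[n][k]. A rough host-side usage sketch, assuming the file above is built into ATen; shapes follow the TORCH_CHECKs in the entry point (the function name wrapping the call is illustrative):

#include <ATen/ATen.h>

at::Tensor weight_int8pack_mm_cuda_example() {
  auto x = at::randn({8, 64}, at::kCUDA);                         // [B, K] float
  auto w = at::randint(-128, 127, {32, 64},
                       at::device(at::kCUDA).dtype(at::kChar));   // [N, K] int8
  auto scale = at::rand({32}, at::kCUDA);                         // [N] per-row scale
  // Reference result for comparison: x.matmul(w.to(at::kFloat).t()) * scale
  return at::native::_weight_int8pack_mm_cuda(x, w, scale);       // [B, N] float
}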
@ -28,6 +28,22 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<Tensor>& running_mean,
|
||||
const std::optional<Tensor>& running_var,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon,
|
||||
Tensor& out,
|
||||
Tensor& save_mean,
|
||||
Tensor& save_var,
|
||||
Tensor& reserve) {
|
||||
AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support");
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
@@ -120,7 +136,12 @@ size_t _get_cudnn_batch_norm_reserve_space_size(
  return reserve_size;
}

std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
// Param `reserve` is a placeholder, just passing an empty tensor.
// usage:
// auto reserve = torch::empty({0}, torch::device(torch::kCUDA));
// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var,
// reserve);
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
    const Tensor& input_t,
    const Tensor& weight_t,
    const std::optional<Tensor>& bias_t_opt,
@@ -128,7 +149,11 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
    const std::optional<Tensor>& running_var_t_opt,
    bool training,
    double exponential_average_factor,
    double epsilon) {
    double epsilon,
    Tensor& output_t,
    Tensor& save_mean,
    Tensor& save_var,
    Tensor& reserve) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> bias_t_maybe_owned =
      at::borrow_from_optional_tensor(bias_t_opt);
@ -168,9 +193,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
|
||||
training, input->suggest_memory_format(), input->dim());
|
||||
|
||||
auto output_t =
|
||||
at::empty_like(*input, input->options(), input->suggest_memory_format());
|
||||
|
||||
TensorArg output{output_t, "output", 0};
|
||||
|
||||
auto handle = getCudnnHandle();
|
||||
@ -182,15 +204,8 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
|
||||
Constant one(dataType, 1);
|
||||
Constant zero(dataType, 0);
|
||||
Tensor save_mean, save_var;
|
||||
|
||||
Tensor reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
|
||||
auto op = CUDNN_BATCHNORM_OPS_BN;
|
||||
size_t workspace_size;
|
||||
AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
|
||||
@ -238,9 +253,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
reserve_size));
|
||||
} else {
|
||||
reserve = at::empty({0}, input->options().dtype(kByte));
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
|
||||
handle,
|
||||
mode,
|
||||
@ -261,10 +273,48 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
// save_mean and save_var can be undefined
|
||||
// If this causes problems, we can initialize them to empty tensors
|
||||
// of the correct type
|
||||
return std::tuple<Tensor, Tensor, Tensor, Tensor>{
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>{
|
||||
output_t, save_mean, save_var, reserve};
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
|
||||
const Tensor& input_t,
|
||||
const Tensor& weight_t,
|
||||
const std::optional<Tensor>& bias_t_opt,
|
||||
const std::optional<Tensor>& running_mean_t_opt,
|
||||
const std::optional<Tensor>& running_var_t_opt,
|
||||
bool training,
|
||||
double exponential_average_factor,
|
||||
double epsilon) {
|
||||
auto output_t = at::empty_like(
|
||||
input_t, input_t.options(), input_t.suggest_memory_format());
|
||||
Tensor save_mean, save_var, reserve;
|
||||
|
||||
if (training) {
|
||||
int64_t num_features = input_t.size(1);
|
||||
save_mean = at::empty({num_features}, weight_t.options());
|
||||
save_var = at::empty({num_features}, weight_t.options());
|
||||
} else {
|
||||
// This keeps a consistent output with native_batch_norm
|
||||
save_mean = at::empty({0}, weight_t.options());
|
||||
save_var = at::empty({0}, weight_t.options());
|
||||
}
|
||||
|
||||
return cudnn_batch_norm_out(
|
||||
input_t,
|
||||
weight_t,
|
||||
bias_t_opt,
|
||||
running_mean_t_opt,
|
||||
running_var_t_opt,
|
||||
training,
|
||||
exponential_average_factor,
|
||||
epsilon,
|
||||
output_t,
|
||||
save_mean,
|
||||
save_var,
|
||||
reserve);
|
||||
}
|
||||
|
||||
// NB: CuDNN only implements the backward algorithm for batchnorm
|
||||
// in training mode (evaluation mode batchnorm has a different algorithm),
|
||||
// which is why this doesn't accept a 'training' parameter.
|
||||
|
||||
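A hedged sketch of calling the new out-variant with caller-allocated buffers, following the signature and the usage note introduced above; the buffer shapes mirror what the in-tree wrapper allocates, and the wrapper function name here is illustrative:

#include <ATen/ATen.h>

void cudnn_batch_norm_out_example(
    const at::Tensor& input,
    const at::Tensor& weight,
    const at::Tensor& bias,
    at::Tensor& running_mean,
    at::Tensor& running_var) {
  auto output = at::empty_like(input, input.options(), input.suggest_memory_format());
  const int64_t c = input.size(1);
  auto save_mean = at::empty({c}, weight.options());
  auto save_var = at::empty({c}, weight.options());
  auto reserve = at::empty({0}, input.options().dtype(at::kByte));  // placeholder, per the comment above

  at::native::cudnn_batch_norm_out(
      input, weight, bias, running_mean, running_var,
      /*training=*/true,
      /*exponential_average_factor=*/0.1,
      /*epsilon=*/1e-5,
      output, save_mean, save_var, reserve);
}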
@ -1,7 +1,6 @@
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/native/mkldnn/Matmul.h>
|
||||
|
||||
@ -428,56 +427,74 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool use_mkldnn_typed_matmul(
|
||||
bool use_mkldnn_bf16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
bool dtype_check = false;
|
||||
if constexpr (std::is_same_v<T, c10::BFloat16>) {
|
||||
#if defined(__aarch64__)
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
dtype_check = use_mkldnn_bf16_matmul() &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
|
||||
}
|
||||
#else
|
||||
dtype_check = dtype_check && use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == kBFloat16);
|
||||
if (mkldnn_bf16_device_check_arm()) {
|
||||
// onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g.
|
||||
// Arm Neoverse V1 so, don't restrict the mkldnn_matmul only for bf16
|
||||
// inputs, allow it for float as well
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() &&
|
||||
(mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
|
||||
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
} else
|
||||
#endif
|
||||
} else if constexpr (std::is_same_v<T, c10::Half>) {
|
||||
dtype_check = dtype_check && use_mkldnn_fp16_matmul() &&
|
||||
(mat1.scalar_type() == kHalf);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
dtype_check = dtype_check &&
|
||||
(use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
|
||||
(mat1.scalar_type() == kFloat);
|
||||
{
|
||||
return (
|
||||
use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
|
||||
mat2.scalar_type() == kBFloat16 &&
|
||||
(!result.defined() || result.scalar_type() == kBFloat16) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
if (!dtype_check) {
|
||||
return false;
|
||||
}
|
||||
bool size_check =
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
|
||||
dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
|
||||
(!result.defined() || result.scalar_type() == mat1.scalar_type());
|
||||
return dtype_check && size_check;
|
||||
}
|
||||
|
||||
bool use_mkldnn_fp16_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
|
||||
mat2.scalar_type() == kHalf &&
|
||||
(!result.defined() || result.scalar_type() == kHalf) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_bf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_tf32_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
return (
|
||||
use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
|
||||
mat2.scalar_type() == kFloat &&
|
||||
(!result.defined() || result.scalar_type() == kFloat) &&
|
||||
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
|
||||
}
|
||||
|
||||
bool use_mkldnn_matmul(
|
||||
const Tensor& mat1,
|
||||
const Tensor& mat2,
|
||||
const Tensor& result) {
|
||||
auto mat1_type = mat1.scalar_type();
|
||||
if (mat1_type != kBFloat16 || mat1_type != kHalf || mat1_type != kFloat) {
|
||||
return false;
|
||||
}
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
kBFloat16, kHalf, mat1.scalar_type(), "use_mkldnn_matmul", [&] {
|
||||
return use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
|
||||
});
|
||||
return false;
|
||||
return (
|
||||
use_mkldnn_bf16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_fp16_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_bf32_matmul(mat1, mat2, result) ||
|
||||
use_mkldnn_tf32_matmul(mat1, mat2, result));
|
||||
}
|
||||
|
||||
static void _mkldnn_matmul_i8i8i32_with_primitive(
|
||||
|
||||
@ -469,4 +469,94 @@ Tensor _weight_int4pack_mm_xpu(
|
||||
|
||||
return C;
|
||||
}
|
||||
|
||||
Tensor& _int_mm_out_xpu(
|
||||
const Tensor& self,
|
||||
const Tensor& mat2,
|
||||
Tensor& result) {
|
||||
TORCH_CHECK(
|
||||
self.dim() == 2,
|
||||
"Expected self to be of dimension 2 but got ",
|
||||
self.dim());
|
||||
TORCH_CHECK(
|
||||
mat2.dim() == 2,
|
||||
"Expected mat2 to be of dimension 2 but got ",
|
||||
mat2.dim());
|
||||
TORCH_CHECK(
|
||||
self.size(1) == mat2.size(0),
|
||||
"self.size(1) needs to match mat2.size(0) but got ",
|
||||
self.size(1),
|
||||
" and ",
|
||||
mat2.size(0));
|
||||
|
||||
TORCH_CHECK(
|
||||
self.dtype() == at::kChar,
|
||||
"Expected self dtype to be of type int8 but got ",
|
||||
self.dtype());
|
||||
TORCH_CHECK(
|
||||
mat2.dtype() == at::kChar,
|
||||
"Expected mat2 dtype to be of type int8 but got ",
|
||||
mat2.dtype());
|
||||
TORCH_CHECK(
|
||||
result.dtype() == at::kInt,
|
||||
"Expected result dtype to be of type kInt but got ",
|
||||
result.dtype());
|
||||
TORCH_CHECK(
|
||||
result.size(0) == self.size(0),
|
||||
"Expected result.size(0) to be ",
|
||||
self.size(0),
|
||||
" but got ",
|
||||
result.size(0));
|
||||
TORCH_CHECK(
|
||||
result.size(1) == mat2.size(1),
|
||||
"Expected result.size(1) to be ",
|
||||
mat2.size(1),
|
||||
" but got ",
|
||||
result.size(1));
|
||||
|
||||
TORCH_CHECK(
|
||||
result.dim() == 2,
|
||||
"Expected result to be of dimension 2 but got ",
|
||||
result.dim());
|
||||
|
||||
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
|
||||
|
||||
if (result.numel() == 0 || self.size(1) == 0) {
|
||||
return result.zero_();
|
||||
}
|
||||
|
||||
Tensor bias = at::Tensor();
|
||||
Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat));
|
||||
Tensor mat2_zero_points = at::Tensor();
|
||||
auto post_op_args = torch::List<std::optional<at::Scalar>>();
|
||||
|
||||
at::native::onednn::quantized_matmul(
|
||||
self.contiguous(),
|
||||
1.0,
|
||||
0,
|
||||
mat2.contiguous(),
|
||||
mat2_scales,
|
||||
mat2_zero_points,
|
||||
bias,
|
||||
result,
|
||||
1.0,
|
||||
0,
|
||||
result.scalar_type(),
|
||||
/*other*/ std::nullopt,
|
||||
/*other scale*/ 1.0,
|
||||
/*other zp*/ 0,
|
||||
/*binary post op*/ "none",
|
||||
/*binary alpha*/ 1.0,
|
||||
/*post_op_name*/ "none",
|
||||
post_op_args,
|
||||
/*post_op_algorithm*/ "none",
|
||||
/*m2_trans*/ true);
|
||||
return result;
|
||||
}
|
||||
|
||||
Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
|
||||
Tensor result =
|
||||
at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
|
||||
return _int_mm_out_xpu(self, mat2, result);
|
||||
}
|
||||
} // namespace at::native
|
||||
|
||||
@@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
  if (C10_UNLIKELY(!library)) {
    auto device = MPSDevice::getInstance()->device();
    NSError* error = nil;
    auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
    library = [device newLibraryWithData:getSectionData(section_name) error:&error];
    library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
    TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
  }
  return library;

@ -33,21 +33,15 @@ struct shrink_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardsigmoid_functor {
|
||||
template <typename T>
|
||||
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardsigmoid, float, float);
|
||||
REGISTER_UNARY_OP(hardsigmoid, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardswish_functor {
|
||||
template <typename T>
|
||||
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardswish, float, float);
|
||||
REGISTER_UNARY_OP(hardswish, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardswish_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardswish_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct leaky_relu_functor {
|
||||
template <typename T>
|
||||
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -113,18 +113,12 @@ kernel void ampUpdateScale(
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(float);
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
|
||||
#endif
|
||||
|
||||
@ -590,9 +590,7 @@ kernel void attention(
|
||||
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(float);
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(float);
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
|
||||
#endif
|
||||
|
||||
@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
|
||||
};
|
||||
|
||||
struct nextafter_functor {
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename U>
|
||||
struct bit_type {};
|
||||
template <>
|
||||
struct bit_type<float> {
|
||||
using type = int;
|
||||
};
|
||||
template <>
|
||||
struct bit_type<half> {
|
||||
using type = short;
|
||||
};
|
||||
#endif
|
||||
template <typename T>
|
||||
inline T operator()(const T a, const T b) {
|
||||
#if __METAL_VERSION__ >= 310
|
||||
return static_cast<T>(::metal::nextafter(a, b));
|
||||
#else
|
||||
using U = typename bit_type<T>::type;
|
||||
if (a == b) {
|
||||
return a;
|
||||
}
|
||||
if (::metal::isunordered(a, b)) {
|
||||
return NAN;
|
||||
}
|
||||
if (a == 0) {
|
||||
constexpr auto eps = as_type<T>(static_cast<U>(1));
|
||||
return b > 0 ? eps : -eps;
|
||||
}
|
||||
auto bits = as_type<U>(a);
|
||||
(a > 0) ^ (a > b) ? bits++ : bits--;
|
||||
return as_type<T>(bits);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
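The fallback deleted above emulates nextafter on pre-3.1 Metal by stepping the value's bit pattern one ulp toward the target. A host-side C++20 analogue of the same trick for float (a sketch, not part of the patch):

#include <bit>
#include <cmath>
#include <cstdint>

// Sketch of the bit-stepping technique: reinterpret the float as an integer,
// then move one ulp toward b. Mirrors the deleted Metal fallback for float.
inline float nextafter_bits(float a, float b) {
  if (a == b) return a;
  if (std::isnan(a) || std::isnan(b)) return NAN;
  if (a == 0.0f) {
    const float eps = std::bit_cast<float>(std::uint32_t{1});  // smallest subnormal
    return b > 0.0f ? eps : -eps;
  }
  auto bits = std::bit_cast<std::uint32_t>(a);
  bits = ((a > 0.0f) != (a > b)) ? bits + 1 : bits - 1;  // step toward b
  return std::bit_cast<float>(bits);
}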
@ -344,13 +315,6 @@ struct fmod_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Some helper defines
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#define _METAL_310_PLUS(x) x
|
||||
#else
|
||||
#define _METAL_310_PLUS(x)
|
||||
#endif
|
||||
|
||||
#define REGISTER_INTEGER_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, long, long); \
|
||||
REGISTER_BINARY_OP(NAME, int, int); \
|
||||
@ -370,12 +334,12 @@ struct fmod_functor {
|
||||
#define REGISTER_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
// Complex binary functions
|
||||
REGISTER_BINARY_OP(polar, float, float2);
|
||||
|
||||
@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
|
||||
REGISTER_SEARCHSORTED_OP(float, long);
|
||||
REGISTER_SEARCHSORTED_OP(half, int);
|
||||
REGISTER_SEARCHSORTED_OP(half, long);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, int);
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, long);
|
||||
#endif
|
||||
REGISTER_SEARCHSORTED_OP(char, int);
|
||||
REGISTER_SEARCHSORTED_OP(char, long);
|
||||
REGISTER_SEARCHSORTED_OP(uchar, int);
|
||||
|
||||
@ -96,6 +96,4 @@ kernel void col2im_kernel(
|
||||
INSTANTIATE_COL2IM(bool);
|
||||
INSTANTIATE_COL2IM(float);
|
||||
INSTANTIATE_COL2IM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_COL2IM(bfloat);
|
||||
#endif
|
||||
|
||||
@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
|
||||
REGISTER_CROSS_FUNC(char);
|
||||
REGISTER_CROSS_FUNC(uchar);
|
||||
REGISTER_CROSS_FUNC(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_FUNC(bfloat);
|
||||
#endif
|
||||
|
||||
template <typename T, typename U>
|
||||
kernel void cross(
|
||||
@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
|
||||
REGISTER_CROSS_OP(char);
|
||||
REGISTER_CROSS_OP(uchar);
|
||||
REGISTER_CROSS_OP(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
using metal::max;
|
||||
#if __METAL_VERSION__ >= 310
|
||||
bfloat max(bfloat a, bfloat b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
#define kmaxThreadGroups 32
|
||||
#define kmaxTensors 32
|
||||
@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
|
||||
REGISTER_ADAM_OPS_QUART(float, half);
|
||||
REGISTER_ADAM_OPS_QUART(half, float);
|
||||
REGISTER_ADAM_OPS_QUART(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_ADAM_OPS_QUART(float, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, float);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
inline void sgd_momentum_math(
|
||||
@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
|
||||
REGISTER_FUSED_SGD_OP(half);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_FUSED_SGD_OP(bfloat);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -106,9 +106,7 @@ kernel void polygamma(
|
||||
constant int64_t& order [[buffer(2)]], \
|
||||
uint id [[thread_position_in_grid]]);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_GAMMA_KERNELS(half, half);
|
||||
INSTANTIATE_GAMMA_KERNELS(float, float);
|
||||
INSTANTIATE_GAMMA_KERNELS(bool, float);
|
||||
|
||||
@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
|
||||
INSTANTIATE_IM2COL(float2);
|
||||
INSTANTIATE_IM2COL(half);
|
||||
INSTANTIATE_IM2COL(half2);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_IM2COL(bfloat);
|
||||
#endif
|
||||
|
||||
@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
|
||||
REGISTER_INDEX_OP(put_accumulate, char, char);
|
||||
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
|
||||
REGISTER_INDEX_OP(put_accumulate, bool, bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
template <typename StridesT, typename DataT>
|
||||
kernel void kernel_index_offsets(
|
||||
@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
|
||||
INSTANTIATE_INDEX_COPY(uchar, int);
|
||||
INSTANTIATE_INDEX_COPY(uchar, long);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INDEX_COPY(bfloat, int);
|
||||
INSTANTIATE_INDEX_COPY(bfloat, long);
|
||||
#endif
|
||||
INSTANTIATE_INDEX_COPY(float2, int);
|
||||
INSTANTIATE_INDEX_COPY(float2, long);
|
||||
INSTANTIATE_INDEX_COPY(half2, int);
|
||||
|
||||
@@ -288,7 +288,6 @@ kernel void layer_norm_looped(
#define instantiate_layer_norm(DTYPE) \
  instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)

instantiate_layer_norm(float) instantiate_layer_norm(half)
#if __METAL_VERSION__ >= 310
instantiate_layer_norm(bfloat)
#endif
instantiate_layer_norm(float);
instantiate_layer_norm(half);
instantiate_layer_norm(bfloat);

@ -635,9 +635,7 @@ kernel void applyPivots(
|
||||
|
||||
INSTANTIATE_NAIVE_MM(float);
|
||||
INSTANTIATE_NAIVE_MM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_NAIVE_MM(bfloat);
|
||||
#endif
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_NAIVE_MM(short);
|
||||
|
||||
@ -48,3 +48,14 @@ struct PoolingBackwardParams {
|
||||
::c10::metal::array<idx_type_t, N> grad_output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct MaxUnpoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
||||
@ -168,6 +168,16 @@ PoolOffsets find_pool_offsets(
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
case 3:
|
||||
return find_pool_offsets_dim_specific<3>(
|
||||
output_sizes,
|
||||
output_strides,
|
||||
indices_strides,
|
||||
input_strides,
|
||||
pooling_dim_indices,
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
}
|
||||
return PoolOffsets();
|
||||
}
|
||||
@ -292,6 +302,68 @@ kernel void max_pool_backward(
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void max_unpool_impl(
|
||||
device T* output,
|
||||
T input_element,
|
||||
int32_t input_index,
|
||||
constant int32_t* output_sizes,
|
||||
constant int32_t* output_strides,
|
||||
int32_t pooling_dims) {
|
||||
int32_t size_prod = 1;
|
||||
int32_t pool_offset = 0;
|
||||
|
||||
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
|
||||
auto next_size_prod = output_sizes[dim] * size_prod;
|
||||
pool_offset +=
|
||||
output_strides[dim] * ((input_index % next_size_prod) / size_prod);
|
||||
size_prod *= output_sizes[dim];
|
||||
}
|
||||
|
||||
output[pool_offset] = input_element;
|
||||
}
|
||||
|
||||
// Kernel computes one element of the grad input per kernel call.
|
||||
template <typename T>
|
||||
kernel void max_unpool(
|
||||
device T* output [[buffer(0)]],
|
||||
constant T* input [[buffer(1)]],
|
||||
constant int64_t* indices [[buffer(2)]],
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto input_sizes = params.input_sizes.data();
|
||||
auto input_strides = params.input_strides.data();
|
||||
auto output_sizes = params.output_sizes.data();
|
||||
auto output_strides = params.output_strides.data();
|
||||
auto indices_strides = params.indices_strides.data();
|
||||
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// NOTE: Since we're doing unpooling, the variable names "input" and "output"
|
||||
// are reversed compared to the pooling operations. So in `find_pool_offsets`,
|
||||
// we need to map "input" -> "output" and "output" -> "input".
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
/*output_sizes=*/input_sizes,
|
||||
/*output_strides=*/input_strides,
|
||||
indices_strides,
|
||||
/*input_strides=*/output_strides,
|
||||
/*pooling_dim_indices=*/nullptr,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/true,
|
||||
tid);
|
||||
|
||||
max_unpool_impl<T>(
|
||||
output + offsets.input_leading,
|
||||
input[offsets.output],
|
||||
indices[offsets.indices],
|
||||
output_sizes + leading_dims,
|
||||
output_strides + leading_dims,
|
||||
pooling_dims);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AvgPoolIterBounds {
|
||||
T start;
|
||||
@ -358,7 +430,6 @@ void avg_pool_3d_input_iter(
|
||||
auto divisor = has_divisor_override
|
||||
? divisor_override
|
||||
: (bounds0.count) * (bounds1.count) * (bounds2.count);
|
||||
auto size12 = input_sizes[1] * input_sizes[2];
|
||||
|
||||
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
|
||||
auto offset0 = input_strides[0] * i0;
|
||||
@ -376,6 +447,64 @@ void avg_pool_3d_input_iter(
|
||||
*output = value_sum / static_cast<T>(divisor);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void avg_pool_backward_3d_input_iter(
|
||||
device AtomicType_t<T>* grad_input,
|
||||
constant T* grad_output,
|
||||
constant int32_t* grad_input_sizes,
|
||||
constant int32_t* grad_input_strides,
|
||||
int32_t grad_input_leading_offset,
|
||||
thread int32_t (&pooling_dim_indices)[3],
|
||||
constant int32_t* kernel_size,
|
||||
constant int32_t* stride,
|
||||
constant int32_t* padding,
|
||||
bool count_include_pad,
|
||||
bool has_divisor_override,
|
||||
int32_t divisor_override) {
|
||||
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
|
||||
grad_input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
|
||||
grad_input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
|
||||
grad_input_sizes,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
count_include_pad);
|
||||
|
||||
auto divisor = has_divisor_override
|
||||
? divisor_override
|
||||
: (bounds0.count) * (bounds1.count) * (bounds2.count);
|
||||
auto grad_val = *grad_output / static_cast<T>(divisor);
|
||||
|
||||
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
|
||||
auto offset0 = grad_input_strides[0] * i0;
|
||||
|
||||
for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
|
||||
auto offset1 = grad_input_strides[1] * i1;
|
||||
|
||||
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
|
||||
auto offset2 = grad_input_strides[2] * i2;
|
||||
auto pool_offset = offset0 + offset1 + offset2;
|
||||
|
||||
AtomicType<T>::atomic_add(
|
||||
grad_input, grad_input_leading_offset + pool_offset, grad_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel computes one element of the output per kernel call.
|
||||
template <typename T>
|
||||
kernel void avg_pool(
|
||||
@ -428,31 +557,97 @@ kernel void avg_pool(
|
||||
params.divisor_override);
|
||||
}
|
||||
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
template <typename T>
|
||||
kernel void avg_pool_backward(
|
||||
device AtomicType_t<T>* grad_input [[buffer(0)]],
|
||||
constant T* grad_output [[buffer(1)]],
|
||||
constant AvgPoolingParams<5>& params [[buffer(2)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto grad_input_sizes = params.input_sizes.data();
|
||||
auto grad_input_strides = params.input_strides.data();
|
||||
auto grad_output_sizes = params.output_sizes.data();
|
||||
auto grad_output_strides = params.output_strides.data();
|
||||
auto kernel_size = params.kernel_size.data();
|
||||
auto stride = params.stride.data();
|
||||
auto padding = params.padding.data();
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// This buffer keeps track of the pooling dimension indices of this thread's
|
||||
// element of the output. We need to fill it with the proper values below.
|
||||
int32_t pooling_dim_indices[3];
|
||||
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
grad_output_sizes,
|
||||
grad_output_strides,
|
||||
/*indices_strides=*/nullptr,
|
||||
grad_input_strides,
|
||||
pooling_dim_indices,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/false,
|
||||
tid);
|
||||
|
||||
grad_output += offsets.output;
|
||||
grad_input_sizes += leading_dims;
|
||||
grad_input_strides += leading_dims;
|
||||
|
||||
avg_pool_backward_3d_input_iter<T>(
|
||||
grad_input,
|
||||
grad_output,
|
||||
grad_input_sizes,
|
||||
grad_input_strides,
|
||||
offsets.input_leading,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
params.count_include_pad,
|
||||
params.has_divisor_override,
|
||||
params.divisor_override);
|
||||
}
|
||||
|
||||
#define REGISTER_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool<DTYPE>( \
|
||||
device DTYPE * output [[buffer(0)]], \
|
||||
constant DTYPE * input [[buffer(1)]], \
|
||||
constant int64_t* indices [[buffer(2)]], \
|
||||
constant MaxUnpoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
|
||||
#define REGISTER_POOL_BACKWARD_OP(DTYPE) \
|
||||
template [[host_name("max_pool_backward_" #DTYPE)]] \
|
||||
kernel void max_pool_backward<DTYPE>( \
|
||||
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
|
||||
constant DTYPE * grad_output_ [[buffer(1)]], \
|
||||
constant int64_t* grad_indices_ [[buffer(2)]], \
|
||||
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]); \
|
||||
\
|
||||
template [[host_name("avg_pool_backward_" #DTYPE)]] \
|
||||
kernel void avg_pool_backward<DTYPE>( \
|
||||
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
|
||||
constant DTYPE * grad_output [[buffer(1)]], \
|
||||
constant AvgPoolingParams<5> & params [[buffer(2)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
REGISTER_POOL_OP(float);
|
||||
REGISTER_POOL_OP(half);
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_POOL_OP(int);
|
||||
REGISTER_POOL_OP(long);
|
||||
REGISTER_POOL_OP(short);
|
||||
@ -460,10 +655,6 @@ REGISTER_POOL_OP(char);
|
||||
REGISTER_POOL_OP(uchar);
|
||||
REGISTER_POOL_OP(bool);
|
||||
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(float);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
|
||||
#endif
|
||||
REGISTER_POOL_BACKWARD_OP(float);
|
||||
REGISTER_POOL_BACKWARD_OP(half);
|
||||
REGISTER_POOL_BACKWARD_OP(bfloat);
|
||||
|
||||
@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
|
||||
INSTANTIATE_INT4MV(half, 128);
|
||||
INSTANTIATE_INT4MV(float, 256);
|
||||
INSTANTIATE_INT4MV(half, 256);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INT4MV(bfloat, 32);
|
||||
INSTANTIATE_INT4MV(bfloat, 64);
|
||||
INSTANTIATE_INT4MV(bfloat, 128);
|
||||
INSTANTIATE_INT4MV(bfloat, 256);
|
||||
#endif
|
||||
|
||||
// ------------------------------ int8 MM For M >= 12 ------------------------------------
|
||||
/**
|
||||
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
|
||||
using simdgroup_type8x8 = simdgroup_half8x8;
|
||||
using type4 = half4;
|
||||
};
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <> struct BlockType<bfloat> {
|
||||
using simdgroup_type8x8 = simdgroup_bfloat8x8;
|
||||
using type4 = bfloat4;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
|
||||
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
|
||||
|
||||
INSTANTIATE_MM(float, char, get_scale_zero_q8);
|
||||
INSTANTIATE_MM(half, char, get_scale_zero_q8);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
|
||||
#endif
|
||||
// ------------------------------ int8 MM For M < 12 ------------------------------------
|
||||
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
|
||||
|
||||
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
|
||||
|
||||
INSTANTIATE_MV(float);
|
||||
INSTANTIATE_MV(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MV(bfloat);
|
||||
#endif
|
||||
|
||||
@ -192,6 +192,4 @@ template <typename T>
|
||||
|
||||
instantiate_rms(float)
|
||||
instantiate_rms(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_rms(bfloat)
|
||||
#endif // clang-format on
|
||||
|
||||
@ -23,6 +23,4 @@ kernel void renorm(
|
||||
|
||||
REGISTER_RENORM_OP(float);
|
||||
REGISTER_RENORM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_RENORM_OP(bfloat);
|
||||
#endif
|
||||
|
||||
@ -25,379 +25,6 @@ struct LogAddExp {
|
||||
};
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ < 310
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMinOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::min(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::max());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct CumMaxOp {
|
||||
static acc_t apply(acc_t a, acc_t b) {
|
||||
return metal::max(a, b);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return static_cast<acc_t>(
|
||||
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
|
||||
: metal::numeric_limits<T>::lowest());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename acc_t = accum_t<T>>
|
||||
struct LogCumSumExpOp {
|
||||
static acc_t apply(acc_t x, acc_t y) {
|
||||
return LogAddExp{}(x, y);
|
||||
}
|
||||
static acc_t identity() {
|
||||
return -metal::numeric_limits<acc_t>::infinity();
|
||||
}
|
||||
};
|
||||
|
||||
// Inclusive scan along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_rows [[buffer(2)]],
|
||||
constant uint& row_size [[buffer(3)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[offset + col] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant uint& num_orows [[buffer(2)]],
|
||||
constant uint& num_irows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[idx] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along innermost dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_innermost_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_rows [[buffer(3)]],
|
||||
constant uint& row_size [[buffer(4)]],
|
||||
uint row [[thread_position_in_grid]]) {
|
||||
if (row >= num_rows)
|
||||
return;
|
||||
|
||||
const uint offset = row * row_size;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
for (uint col = 0; col < row_size; col++) {
|
||||
T val = input[offset + col];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[offset + col] = static_cast<T>(accumulator);
|
||||
indices[offset + col] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Inclusive scan with indices along outer dimension for contiguous tensors
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_with_indices_contiguous_outer_dim(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* values [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant uint& num_orows [[buffer(3)]],
|
||||
constant uint& num_irows [[buffer(4)]],
|
||||
constant uint& row_size [[buffer(5)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const uint orow = thread_index / num_irows;
|
||||
const uint irow = thread_index % num_irows;
|
||||
|
||||
if (orow >= num_orows)
|
||||
return;
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
int64_t best_idx = 0;
|
||||
|
||||
const uint idx_base = orow * row_size * num_irows + irow;
|
||||
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
|
||||
T val = input[idx];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
|
||||
accumulator = accum_val;
|
||||
best_idx = col;
|
||||
}
|
||||
values[idx] = static_cast<T>(accumulator);
|
||||
indices[idx] = best_idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Shared utility functions for strided kernels
|
||||
inline long calculate_non_scan_elements(
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long total = 1;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
total *= sizes[i];
|
||||
}
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
inline void thread_index_to_coordinates(
|
||||
uint index,
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* sizes,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long remaining_index = index;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
pos[i] = remaining_index % sizes[i];
|
||||
remaining_index /= sizes[i];
|
||||
} else {
|
||||
pos[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline long calculate_base_offset(
|
||||
int pos[c10::metal::max_ndim],
|
||||
constant long* strides,
|
||||
uint ndim,
|
||||
uint scan_dim) {
|
||||
long offset = 0;
|
||||
for (uint i = 0; i < ndim; ++i) {
|
||||
if (i != scan_dim) {
|
||||
offset += pos[i] * strides[i];
|
||||
}
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
// Generic strided scan kernel
|
||||
template <typename T, typename Op, typename acc_t = accum_t<T>>
|
||||
kernel void scan_strided(
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
constant long* sizes [[buffer(2)]],
|
||||
constant long* input_strides [[buffer(3)]],
|
||||
constant long* output_strides [[buffer(4)]],
|
||||
constant uint& ndim [[buffer(5)]],
|
||||
constant uint& scan_dim [[buffer(6)]],
|
||||
uint thread_index [[thread_position_in_grid]]) {
|
||||
const long total_non_scan_elements =
|
||||
calculate_non_scan_elements(sizes, ndim, scan_dim);
|
||||
if (thread_index >= total_non_scan_elements) {
|
||||
return;
|
||||
}
|
||||
|
||||
int pos[c10::metal::max_ndim];
|
||||
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
|
||||
|
||||
const long input_base_offset =
|
||||
calculate_base_offset(pos, input_strides, ndim, scan_dim);
|
||||
const long output_base_offset =
|
||||
calculate_base_offset(pos, output_strides, ndim, scan_dim);
|
||||
|
||||
acc_t accumulator = Op::identity();
|
||||
const long scan_size = sizes[scan_dim];
|
||||
const long input_scan_stride = input_strides[scan_dim];
|
||||
const long output_scan_stride = output_strides[scan_dim];
|
||||
|
||||
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
|
||||
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
|
||||
const long output_offset =
|
||||
output_base_offset + scan_idx * output_scan_stride;
|
||||
|
||||
T val = input[input_offset];
|
||||
acc_t accum_val = static_cast<acc_t>(val);
|
||||
accumulator = Op::apply(accumulator, accum_val);
|
||||
output[output_offset] = static_cast<T>(accumulator);
|
||||
}
|
||||
}
|
||||
|
||||
// Generic strided scan with indices kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_strided(
    constant T* input [[buffer(0)]],
    device T* values [[buffer(1)]],
    device int64_t* indices [[buffer(2)]],
    constant long* sizes [[buffer(3)]],
    constant long* input_strides [[buffer(4)]],
    constant long* values_strides [[buffer(5)]],
    constant long* indices_strides [[buffer(6)]],
    constant uint& ndim [[buffer(7)]],
    constant uint& scan_dim [[buffer(8)]],
    uint thread_index [[thread_position_in_grid]]) {
  const long total_non_scan_elements =
      calculate_non_scan_elements(sizes, ndim, scan_dim);
  if (thread_index >= total_non_scan_elements) {
    return;
  }

  int pos[c10::metal::max_ndim];
  thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);

  const long input_base_offset =
      calculate_base_offset(pos, input_strides, ndim, scan_dim);
  const long values_base_offset =
      calculate_base_offset(pos, values_strides, ndim, scan_dim);
  const long indices_base_offset =
      calculate_base_offset(pos, indices_strides, ndim, scan_dim);

  acc_t accumulator = Op::identity();
  int64_t best_idx = 0;
  const long scan_size = sizes[scan_dim];
  const long input_scan_stride = input_strides[scan_dim];
  const long values_scan_stride = values_strides[scan_dim];
  const long indices_scan_stride = indices_strides[scan_dim];

  for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
    const long input_offset = input_base_offset + scan_idx * input_scan_stride;
    const long values_offset =
        values_base_offset + scan_idx * values_scan_stride;
    const long indices_offset =
        indices_base_offset + scan_idx * indices_scan_stride;

    T val = input[input_offset];
    acc_t accum_val = static_cast<acc_t>(val);
    if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
      accumulator = accum_val;
      best_idx = scan_idx;
    }
    values[values_offset] = static_cast<T>(accumulator);
    indices[indices_offset] = best_idx;
  }
}

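// As written, the running extreme and best_idx are updated whenever
// Op::apply(new, accumulator) returns a value equal to the new element, so an
// element that ties the current extreme also advances best_idx to its position.
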
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
  template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
  scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * output [[buffer(1)]], \
      constant uint & num_rows [[buffer(2)]], \
      constant uint & row_size [[buffer(3)]], \
      uint row [[thread_position_in_grid]]); \
  \
  template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
  scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * output [[buffer(1)]], \
      constant uint & num_orows [[buffer(2)]], \
      constant uint & num_irows [[buffer(3)]], \
      constant uint & row_size [[buffer(4)]], \
      uint thread_index [[thread_position_in_grid]]); \
  \
  template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
  scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * output [[buffer(1)]], \
      constant long* sizes [[buffer(2)]], \
      constant long* input_strides [[buffer(3)]], \
      constant long* output_strides [[buffer(4)]], \
      constant uint& ndim [[buffer(5)]], \
      constant uint& scan_dim [[buffer(6)]], \
      uint thread_index [[thread_position_in_grid]]);

#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
  template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
  scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * values [[buffer(1)]], \
      device int64_t* indices [[buffer(2)]], \
      constant uint& num_rows [[buffer(3)]], \
      constant uint& row_size [[buffer(4)]], \
      uint row [[thread_position_in_grid]]); \
  \
  template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
  scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * values [[buffer(1)]], \
      device int64_t* indices [[buffer(2)]], \
      constant uint& num_orows [[buffer(3)]], \
      constant uint& num_irows [[buffer(4)]], \
      constant uint& row_size [[buffer(5)]], \
      uint thread_index [[thread_position_in_grid]]); \
  \
  template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
  scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
      constant DTYPE * input [[buffer(0)]], \
      device DTYPE * values [[buffer(1)]], \
      device int64_t* indices [[buffer(2)]], \
      constant long* sizes [[buffer(3)]], \
      constant long* input_strides [[buffer(4)]], \
      constant long* values_strides [[buffer(5)]], \
      constant long* indices_strides [[buffer(6)]], \
      constant uint& ndim [[buffer(7)]], \
      constant uint& scan_dim [[buffer(8)]], \
      uint thread_index [[thread_position_in_grid]]);

// Simple scan operations
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);

// Scan operations with indices
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);

REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);

#else // __METAL_VERSION__ >= 310

C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;

// The remainder of this file contains cummin and cummax implementations adapted
@@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);

#endif

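For orientation, here is a rough host-side sketch of how a strided scan kernel of this shape might be dispatched from the MPS backend, reusing the getCurrentMPSStream / mtl_setArgs / mtl_dispatch1DJob helpers that appear in the pooling code later in this diff. The function name, kernel-name string, and argument packing below are illustrative assumptions, not the actual call site.

// Sketch only: buffer order must line up with scan_strided's [[buffer(n)]] slots.
static void dispatch_scan_strided_sketch(const Tensor& input, Tensor& output, int64_t dim) {
  MPSStream* mpsStream = getCurrentMPSStream();
  const auto ndim = static_cast<uint32_t>(input.dim());
  const auto scan_dim = static_cast<uint32_t>(dim);
  // One thread per element of the non-scan dimensions.
  const auto numThreads = input.numel() / std::max<int64_t>(input.size(dim), 1);
  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      // Hypothetical kernel name; the real names come from REGISTER_SCAN_OP's host_name.
      auto PSO = lib.getPipelineStateForFunc("logcumsumexp_strided_" + scalarToMetalTypeString(input));
      [computeEncoder setComputePipelineState:PSO];
      // Assumes mtl_setArgs has overloads for the size/stride arrays and scalar args.
      mtl_setArgs(computeEncoder,
                  input,
                  output,
                  input.sizes(),
                  input.strides(),
                  output.strides(),
                  ndim,
                  scan_dim);
      mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
    }
  });
}

The contiguous innermost/outer variants registered above take the simpler num_rows / row_size arguments instead of the full size/stride arrays.
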
@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
|
||||
REGISTER_SPECIAL(int, float);
|
||||
REGISTER_SPECIAL(long, float);
|
||||
REGISTER_SPECIAL(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SPECIAL(bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
@ -100,9 +100,7 @@ kernel void triul(
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half, int);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float2, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half2, int);
|
||||
|
||||
@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
|
||||
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(neg, bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(abs, bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_UNARY_KERNELS2(half, half);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, float);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, bool);
|
||||
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
|
||||
#endif
|
||||
|
||||
@ -70,6 +70,4 @@ kernel void unfold_backward(
|
||||
|
||||
INSTANTIATE_UNFOLD_BACKWARD(float);
|
||||
INSTANTIATE_UNFOLD_BACKWARD(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
|
||||
#endif
|
||||
|
||||
@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
|
||||
INSTANTIATE_UPSAMPLE_3D(uchar);
|
||||
INSTANTIATE_UPSAMPLE_ALL(float);
|
||||
INSTANTIATE_UPSAMPLE_ALL(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UPSAMPLE_ALL(bfloat);
|
||||
#endif
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <ATen/ops/avg_pool2d_backward.h>
|
||||
#include <ATen/ops/avg_pool2d_backward_native.h>
|
||||
#include <ATen/ops/avg_pool2d_native.h>
|
||||
#include <ATen/ops/avg_pool3d_backward_native.h>
|
||||
#include <ATen/ops/avg_pool3d_native.h>
|
||||
#include <ATen/ops/max_pool2d_backward_native.h>
|
||||
#include <ATen/ops/max_pool2d_native.h>
|
||||
@ -21,6 +22,8 @@
|
||||
#include <ATen/ops/max_pool2d_with_indices_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
|
||||
#include <ATen/ops/max_pool3d_with_indices_native.h>
|
||||
#include <ATen/ops/max_unpool2d_native.h>
|
||||
#include <ATen/ops/max_unpool3d_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
@ -492,6 +495,60 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
|
||||
});
|
||||
}
|
||||
|
||||
static void max_unpool_out_mps_template(const Tensor& input,
                                        const Tensor& indices,
                                        IntArrayRef output_size_,
                                        IntArrayRef stride,
                                        IntArrayRef padding,
                                        Tensor& output,
                                        const int32_t pooling_dims,
                                        const std::string& op_name) {
  auto dims = input.dim();
  auto leading_dims = input.dim() - pooling_dims;

  const auto memory_format = input.suggest_memory_format();
  std::vector<int64_t> output_size(dims);
  for (int dim : c10::irange(leading_dims)) {
    output_size[dim] = input.sizes()[dim];
  }
  for (int dim : c10::irange(pooling_dims)) {
    output_size[leading_dims + dim] = output_size_[dim];
  }

  output.resize_(output_size, memory_format);
  output.fill_(0);

  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  const auto numThreads = input.numel();
  MaxUnpoolingParams<5> params;

  params.dims = dims;
  params.pooling_dims = pooling_dims;

  for (const auto dim : c10::irange(dims)) {
    params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
    params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
    params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
    params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
    params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
  }

  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
      auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input));

      getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
      [computeEncoder setComputePipelineState:PSO];
      mtl_setArgs(computeEncoder, output, input, indices, params);

      mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
      getMPSProfiler().endProfileKernel(PSO);
    }
  });
}

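// Note on max_unpool_out_mps_template above: for an (N, C, H_in, W_in) input
// with pooling_dims = 2 and output_size_ = {H_out, W_out} (shapes assumed for
// illustration), the assembled output_size is {N, C, H_out, W_out}: leading
// dims come from the input, trailing dims from output_size_.
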
static void avg_pool2d_template(const Tensor& input,
|
||||
const Tensor& output,
|
||||
const std::optional<Tensor>& grad_output_opt,
|
||||
@ -669,6 +726,64 @@ static void avg_pool_out_mps_template(const Tensor& output,
|
||||
});
|
||||
}
|
||||
|
||||
static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
|
||||
const Tensor& input,
|
||||
const Tensor& grad_output,
|
||||
IntArrayRef _kernel_size,
|
||||
IntArrayRef _stride,
|
||||
IntArrayRef _padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
std::optional<int64_t> divisor_override,
|
||||
const int32_t pooling_dims,
|
||||
const std::string& op_name) {
|
||||
auto [dims, _, kernel_size, stride, padding, __] =
|
||||
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
grad_input.resize_(input.sizes(), memory_format);
|
||||
grad_input.fill_(0);
|
||||
|
||||
id<MTLDevice> device = MPSDevice::getInstance()->device();
|
||||
MPSStream* mpsStream = getCurrentMPSStream();
|
||||
const auto numThreads = grad_output.numel();
|
||||
|
||||
AvgPoolingParams<5> params;
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
params.count_include_pad = count_include_pad;
|
||||
params.has_divisor_override = divisor_override.has_value();
|
||||
if (divisor_override.has_value()) {
|
||||
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
|
||||
}
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
|
||||
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
|
||||
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
|
||||
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
|
||||
}
|
||||
|
||||
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));
|
||||
|
||||
getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
|
||||
[computeEncoder setComputePipelineState:PSO];
|
||||
mtl_setArgs(computeEncoder, grad_input, grad_output, params);
|
||||
|
||||
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
|
||||
getMPSProfiler().endProfileKernel(PSO);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mps
|
||||
|
||||
Tensor mps_max_pool2d(const Tensor& input,
|
||||
@ -896,6 +1011,68 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
|
||||
return grad_input;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling2d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
/*stride=*/{},
|
||||
/*padding=*/{},
|
||||
output,
|
||||
/*pooling_dims=*/2,
|
||||
"max_unpool2d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor& max_unpooling3d_forward_out_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
Tensor& output) {
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor max_unpooling3d_forward_mps(const Tensor& self,
|
||||
const Tensor& indices,
|
||||
IntArrayRef output_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding) {
|
||||
auto output = at::empty({0}, self.options());
|
||||
mps::max_unpool_out_mps_template(self,
|
||||
indices,
|
||||
output_size,
|
||||
stride,
|
||||
padding,
|
||||
output,
|
||||
/*pooling_dims=*/3,
|
||||
"max_unpool3d");
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
|
||||
(const Tensor& input,
|
||||
int64_t kH,
|
||||
@ -965,4 +1142,26 @@ TORCH_IMPL_FUNC(avg_pool3d_out_mps)
|
||||
"avg_pool3d");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
bool ceil_mode,
|
||||
bool count_include_pad,
|
||||
std::optional<int64_t> divisor_override,
|
||||
const Tensor& grad_input) {
|
||||
mps::avg_pool_backward_out_mps_template(grad_input,
|
||||
input,
|
||||
grad_output,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
ceil_mode,
|
||||
count_include_pad,
|
||||
divisor_override,
|
||||
/*pooling_dims=*/3,
|
||||
"avg_pool3d_backward");
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -719,6 +719,7 @@
|
||||
dispatch:
|
||||
CPU, CUDA: all_out
|
||||
MPS: all_out_mps
|
||||
MTIA: all_out_mtia
|
||||
|
||||
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
device_check: NoCheck # TensorIterator
|
||||
@ -808,6 +809,7 @@
|
||||
CPU, Meta: arange_out
|
||||
CUDA: arange_cuda_out
|
||||
MPS: arange_mps_out
|
||||
MTIA: arange_mtia_out
|
||||
cpp_no_default_args: ['step']
|
||||
|
||||
# This function is a temporary hack to allow tracing of arange like constructs with dynamic
|
||||
@ -1889,7 +1891,10 @@
|
||||
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm
|
||||
autogen: cudnn_batch_norm.out
|
||||
|
||||
- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
|
||||
dispatch:
|
||||
CUDA: cudnn_batch_norm_out
|
||||
|
||||
# NB: You can only use this if you used cudnn_batch_norm training=True
|
||||
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
|
||||
@ -4182,11 +4187,13 @@
|
||||
dispatch:
|
||||
CPU: _int_mm_cpu
|
||||
CUDA: _int_mm_cuda
|
||||
XPU: _int_mm_xpu
|
||||
|
||||
- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CPU: _int_mm_out_cpu
|
||||
CUDA: _int_mm_out_cuda
|
||||
XPU: _int_mm_out_xpu
|
||||
|
||||
- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
|
||||
dispatch:
|
||||
@ -4223,6 +4230,7 @@
|
||||
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
|
||||
dispatch:
|
||||
CPU: _weight_int8pack_mm_cpu
|
||||
CUDA: _weight_int8pack_mm_cuda
|
||||
MPS: _weight_int8pack_mm_mps
|
||||
|
||||
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
|
||||
@ -7124,18 +7132,21 @@
|
||||
dispatch:
|
||||
CPU: _scaled_mm_cpu
|
||||
CUDA: _scaled_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
|
||||
variants: function
|
||||
dispatch:
|
||||
CPU: _scaled_mm_out_cpu
|
||||
CUDA: _scaled_mm_out_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
|
||||
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
@ -10487,6 +10498,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_add_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_scalar_kernel_mtia_
|
||||
autogen: _foreach_add.Scalar_out
|
||||
|
||||
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
|
||||
@ -10495,6 +10507,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia
|
||||
|
||||
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10502,6 +10515,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
|
||||
CUDA: foreach_tensor_add_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_list_kernel_mtia_
|
||||
autogen: _foreach_add.List_out
|
||||
|
||||
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10532,6 +10546,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_add_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_add_tensor_kernel_mtia_
|
||||
autogen: _foreach_add.Tensor_out
|
||||
|
||||
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10592,6 +10607,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_scalar_kernel_mtia_
|
||||
autogen: _foreach_mul.Scalar_out
|
||||
|
||||
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
|
||||
@ -10600,6 +10616,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10607,6 +10624,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_list_kernel_mtia_
|
||||
autogen: _foreach_mul.List_out
|
||||
|
||||
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
|
||||
@ -10630,6 +10648,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia
|
||||
|
||||
- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10637,6 +10656,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
|
||||
CUDA: foreach_tensor_mul_tensor_kernel_cuda_
|
||||
MTIA: foreach_tensor_mul_tensor_kernel_mtia_
|
||||
autogen: _foreach_mul.Tensor_out
|
||||
|
||||
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
|
||||
@ -10933,6 +10953,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia
|
||||
|
||||
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10954,6 +10975,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
|
||||
CUDA: foreach_tensor_addcmul_scalar_cuda_
|
||||
MTIA: foreach_tensor_addcmul_scalar_mtia_
|
||||
autogen: _foreach_addcmul.Scalar_out
|
||||
|
||||
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
|
||||
@ -10978,6 +11000,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow
|
||||
CUDA: foreach_tensor_abs_cuda
|
||||
MTIA: foreach_tensor_abs_mtia
|
||||
|
||||
- func: _foreach_abs_(Tensor(a!)[] self) -> ()
|
||||
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
|
||||
@ -10985,6 +11008,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_abs_slow_
|
||||
CUDA: foreach_tensor_abs_cuda_
|
||||
MTIA: foreach_tensor_abs_mtia_
|
||||
autogen: _foreach_abs.out
|
||||
|
||||
- func: _foreach_acos(Tensor[] self) -> Tensor[]
|
||||
@ -11319,6 +11343,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_norm_slow
|
||||
CUDA: foreach_tensor_norm_cuda
|
||||
MTIA: foreach_tensor_norm_mtia
|
||||
autogen: _foreach_norm.Scalar_out
|
||||
|
||||
- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
|
||||
@ -11491,6 +11516,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
|
||||
CUDA: foreach_tensor_sqrt_cuda_
|
||||
MTIA: foreach_tensor_sqrt_mtia_
|
||||
autogen: _foreach_sqrt.out
|
||||
|
||||
- func: _foreach_tan(Tensor[] self) -> Tensor[]
|
||||
@ -11552,6 +11578,7 @@
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
|
||||
CUDA: foreach_tensor_copy_list_kernel_cuda_
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia_
|
||||
autogen: _foreach_copy.out
|
||||
|
||||
- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
|
||||
@ -11559,6 +11586,7 @@
|
||||
variants: function
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: _foreach_copy
|
||||
MTIA: foreach_tensor_copy_list_kernel_mtia
|
||||
|
||||
- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
|
||||
dispatch:
|
||||
@ -12351,6 +12379,7 @@
|
||||
dispatch:
|
||||
CPU: avg_pool3d_backward_out_cpu
|
||||
CUDA: avg_pool3d_backward_out_cuda
|
||||
MPS: avg_pool3d_backward_out_mps
|
||||
MkldnnCPU: mkldnn_avg_pool3d_backward_out
|
||||
|
||||
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
|
||||
@ -12476,24 +12505,28 @@
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_out_cpu
|
||||
CUDA: max_unpooling2d_forward_out_cuda
|
||||
MPS: max_unpooling2d_forward_out_mps
|
||||
|
||||
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling2d_forward_cpu
|
||||
CUDA: max_unpooling2d_forward_cuda
|
||||
MPS: max_unpooling2d_forward_mps
|
||||
|
||||
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_out_cpu
|
||||
CUDA: max_unpooling3d_forward_out_cuda
|
||||
MPS: max_unpooling3d_forward_out_mps
|
||||
|
||||
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
|
||||
python_module: nn
|
||||
dispatch:
|
||||
CPU: max_unpooling3d_forward_cpu
|
||||
CUDA: max_unpooling3d_forward_cuda
|
||||
MPS: max_unpooling3d_forward_mps
|
||||
|
||||
- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
|
||||
python_module: nn
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# generate a list of kernels, but do not actually emit files at config stage
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -11,7 +11,27 @@ endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
--api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -19,15 +39,29 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# Generate the files for both fwd and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
# Generate the files for fwd, fwd_splitkv, fwd_appendkv, and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
@ -44,6 +78,22 @@ if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass")
|
||||
endif()
|
||||
|
||||
# Change make_kernel to make_kernel_pt for bwd
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"
|
||||
|
||||
@ -21,6 +21,8 @@ while IFS= read -r file; do
|
||||
if [ -f "$file" ]; then
|
||||
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
|
||||
sed -i 's/make_kernel/make_kernel_pt/g' "$file"
|
||||
sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
|
||||
echo "Updated: $file"
|
||||
else
|
||||
echo "Skipping: $file (not found)"
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
// keep in sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
|
||||
str.compare(0, 11, "elementwise") == 0)
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
|
||||
str.compare(0, 5, "alibi") == 0)
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string e = str.substr(found_0 + 1);
|
||||
info.rank_info = atoi(e.c_str());
|
||||
}
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@ -1,457 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <bias.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using GemmDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::half_t;
|
||||
using OGradDataType = ck_tile::half_t;
|
||||
using QGradDataType = ck_tile::half_t;
|
||||
using KGradDataType = ck_tile::half_t;
|
||||
using VGradDataType = ck_tile::half_t;
|
||||
using BiasGradDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using GemmDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using LSEDataType = float;
|
||||
using AccDataType = float; // data type for gemm accumulation
|
||||
using DDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using OGradDataType = ck_tile::bf16_t;
|
||||
using QGradDataType = ck_tile::bf16_t;
|
||||
using KGradDataType = ck_tile::bf16_t;
|
||||
using VGradDataType = ck_tile::bf16_t;
|
||||
using BiasGradDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args; some will be passed to kargs, some will be used to compute grids/blocks
|
||||
struct fmha_bwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
void* d_ptr;
|
||||
void* rand_val_ptr;
|
||||
void* dq_ptr;
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
void* dbias_ptr;
|
||||
void* dq_acc_ptr;
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
float scale;
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
ck_tile::index_t stride_dq_acc;
|
||||
ck_tile::index_t stride_dq;
|
||||
ck_tile::index_t stride_dk;
|
||||
ck_tile::index_t stride_dv;
|
||||
ck_tile::index_t stride_dbias;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_do;
|
||||
ck_tile::index_t nhead_stride_lsed;
|
||||
ck_tile::index_t nhead_stride_dq_acc;
|
||||
ck_tile::index_t nhead_stride_dq;
|
||||
ck_tile::index_t nhead_stride_dk;
|
||||
ck_tile::index_t nhead_stride_dv;
|
||||
ck_tile::index_t nhead_stride_dbias;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_do;
|
||||
ck_tile::index_t batch_stride_lsed;
|
||||
ck_tile::index_t batch_stride_dq_acc;
|
||||
ck_tile::index_t batch_stride_dq;
|
||||
ck_tile::index_t batch_stride_dk;
|
||||
ck_tile::index_t batch_stride_dv;
|
||||
ck_tile::index_t batch_stride_dbias;
|
||||
ck_tile::index_t split_stride_dq_acc;
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.bias_ptr,
|
||||
args.lse_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.rand_val_ptr,
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_bias,
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
args.split_stride_dq_acc,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.drop_seed_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdOGradDotOKernel>
|
||||
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
|
||||
args.do_ptr,
|
||||
args.d_ptr,
|
||||
args.p_undrop,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.stride_do,
|
||||
args.stride_o,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_o,
|
||||
args.nhead_stride_lsed,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_o,
|
||||
args.batch_stride_lsed);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
template <typename FmhaBwdConvertQGradKernel>
|
||||
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
auto kargs = [&] {
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
|
||||
{
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
|
||||
args.dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
args.nhead_stride_dq,
|
||||
args.nhead_stride_dq_acc,
|
||||
args.batch_stride_dq,
|
||||
args.batch_stride_dq_acc,
|
||||
args.split_stride_dq_acc);
|
||||
}
|
||||
}();
|
||||
|
||||
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dot_do_o_get_name_();
|
||||
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_convert_dq_get_name_();
|
||||
|
||||
// This is the public API; it will be generated by the script
|
||||
struct fmha_bwd_traits
|
||||
{
|
||||
int hdim_q;
|
||||
int hdim_v;
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
template <int Version = 2>
|
||||
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
|
||||
@ -1,824 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include <bias.hpp>
|
||||
#include <mask.hpp>
|
||||
#include <rotary.hpp>
|
||||
#include <launch_kernel_pt.hpp>
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaFwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdBf8
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Fp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
using VDataType = ck_tile::half_t;
|
||||
using BiasDataType = ck_tile::half_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
using VDataType = ck_tile::bf16_t;
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::fp8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
{
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using KDataType = ck_tile::bf8_t;
|
||||
using VDataType = ck_tile::bf8_t;
|
||||
using BiasDataType = ck_tile::bf8_t;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
|
||||
// runtime args; some will be passed to kargs, some will be used to compute grids/blocks
|
||||
struct fmha_fwd_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
float scale_s;
|
||||
float scale_p;
|
||||
float scale_o;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_randval;
|
||||
ck_tile::index_t nhead_stride_lse;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_randval;
|
||||
ck_tile::index_t batch_stride_lse;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
ck_tile::index_t window_size_left;
|
||||
ck_tile::index_t window_size_right;
|
||||
ck_tile::index_t mask_type;
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
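
// --- Illustrative sketch (not part of the original header) ---------------------------------
// How a caller might fill the shape/stride members above for a contiguous
// [batch, nhead, seqlen, hdim] (BHSD) Q/K/V/O layout with grouped-query attention.
// The BHSD assumption and the helper name are hypothetical; check the layout the
// surrounding integration actually uses before relying on this.
inline void fill_bhsd_strides_example(fmha_fwd_args& a)
{
    a.stride_q       = a.hdim_q;                     // step between query tokens
    a.stride_k       = a.hdim_q;
    a.stride_v       = a.hdim_v;
    a.stride_o       = a.hdim_v;
    a.nhead_stride_q = a.seqlen_q * a.hdim_q;        // step between heads
    a.nhead_stride_k = a.seqlen_k * a.hdim_q;
    a.nhead_stride_v = a.seqlen_k * a.hdim_v;
    a.nhead_stride_o = a.seqlen_q * a.hdim_v;
    a.batch_stride_q = a.nhead_q * a.nhead_stride_q; // step between batches
    a.batch_stride_k = a.nhead_k * a.nhead_stride_k; // K/V may carry fewer heads (GQA/MQA)
    a.batch_stride_v = a.nhead_k * a.nhead_stride_v;
    a.batch_stride_o = a.nhead_q * a.nhead_stride_o;
}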

struct fmha_fwd_splitkv_args
{
    const void* q_ptr;
    const void* k_ptr;
    const void* v_ptr;
    const void* bias_ptr; // bias or alibi_slope pointer
    void* lse_acc_ptr;
    void* o_acc_ptr;
    void* lse_ptr;
    void* o_ptr;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
    bool is_gappy; // differentiates seqstart_k_ptr usage; only used if 'block_table_ptr' is not
                   // nullptr

    const void* cache_batch_idx;

    // the real seqlen_q & seqlen_k are decided as follows:
    // batch mode: seqlen_q = kargs.seqlen_q
    //             seqlen_k = kargs.seqlen_k
    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
    //                        or kargs.seqlen_k_ptr[b]
    //
    // batch mode (kvcache):
    //             seqlen_q = kargs.seqlen_q
    //             seqlen_k = kargs.seqlen_k_ptr[b]
    // group mode (kvcache):
    //             seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
    //
    //             when is_gappy=true:
    //               seqlen_k = kargs.seqlen_k_ptr[b]
    //               seqstart_k_ptr[b] now stores the local offset of each batch
    //
    //             when is_gappy=false:
    //               seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
    //                          or kargs.seqlen_k_ptr[b]
    const void* seqstart_q_ptr;
    const void* seqstart_k_ptr;
    const void* seqlen_k_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_k;
    ck_tile::index_t batch;
    ck_tile::index_t max_seqlen_q;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;
    ck_tile::index_t num_splits;

    float scale_s;
    float scale_p;
    float scale_o;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_bias; // for alibi: set to h when slopes are (b*h), or to 0 when slopes are (1*h)
    ck_tile::index_t stride_o_acc;
    ck_tile::index_t stride_o;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_bias;
    ck_tile::index_t nhead_stride_lse;
    ck_tile::index_t nhead_stride_lse_acc;
    ck_tile::index_t nhead_stride_o_acc;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_lse_acc;
    ck_tile::index_t batch_stride_o_acc;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t split_stride_lse_acc;
    ck_tile::index_t split_stride_o_acc;

    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
};
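
// --- Illustrative sketch (not part of the original header) ---------------------------------
// Host-side mirror of the group-mode length rules documented in the struct above,
// assuming 32-bit entries behind seqstart_q_ptr / seqstart_k_ptr / seqlen_k_ptr.
inline int32_t group_mode_seqlen_q_example(const int32_t* seqstart_q, int b)
{
    return seqstart_q[b + 1] - seqstart_q[b];
}
inline int32_t group_mode_seqlen_k_example(const int32_t* seqstart_k,
                                           const int32_t* seqlen_k, // may be nullptr
                                           int b)
{
    // an explicit per-batch length (kvcache / gappy layouts) takes precedence
    return seqlen_k != nullptr ? seqlen_k[b] : seqstart_k[b + 1] - seqstart_k[b];
}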

struct fmha_fwd_appendkv_args
{
    void* q_ptr;
    void* k_ptr;
    const void* knew_ptr;
    void* v_ptr;
    const void* vnew_ptr;

    const void* seqlen_k_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_knew;
    ck_tile::index_t batch;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;

    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
    ck_tile::index_t rotary_dim;
    bool has_mask;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr

    const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_knew;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_vnew;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_knew;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_vnew;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_knew;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_vnew;
};

template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaKernel::kIsGroupMode)
        {
            return FmhaKernel::MakeKargsImpl(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.rand_val_ptr,
                args.lse_ptr, args.o_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr,
                args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q,
                args.nhead_q / args.nhead_k, args.scale_s, args.scale_p, args.scale_o,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias,
                args.stride_randval, args.stride_o, args.nhead_stride_q, args.nhead_stride_k,
                args.nhead_stride_v, args.nhead_stride_bias, args.nhead_stride_randval,
                args.nhead_stride_lse, args.nhead_stride_o, args.window_size_left,
                args.window_size_right, args.mask_type, args.p_drop, args.s_randval,
                args.drop_seed_offset);
        }
        else
        { // create batch mode kernel arguments
            return FmhaKernel::MakeKargsImpl(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.rand_val_ptr,
                args.lse_ptr, args.o_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q,
                args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k, args.scale_s,
                args.scale_p, args.scale_o, args.stride_q, args.stride_k, args.stride_v,
                args.stride_bias, args.stride_randval, args.stride_o, args.nhead_stride_q,
                args.nhead_stride_k, args.nhead_stride_v, args.nhead_stride_bias,
                args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o,
                args.batch_stride_q, args.batch_stride_k, args.batch_stride_v,
                args.batch_stride_bias, args.batch_stride_randval, args.batch_stride_lse,
                args.batch_stride_o, args.window_size_left, args.window_size_right,
                args.mask_type, args.p_drop, args.s_randval, args.drop_seed_offset);
        }
    }();

    if constexpr(FmhaKernel::kIsGroupMode)
    {
        dim3 grids = FmhaKernel::GridSize(
            args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
        return ck_tile::make_tuple(kargs, grids);
    }
    else
    {
        dim3 grids =
            FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
        return ck_tile::make_tuple(kargs, grids);
    }
}
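
// --- Illustrative sketch (not part of the original header) ---------------------------------
// How a dispatcher typically consumes the (kargs, grids) pair returned above. The
// FmhaKernel::BlockSize() call and the downstream ck_tile launch helpers are assumptions;
// consult the generated fmha_fwd_* sources for the real launch code.
template <typename FmhaKernel>
auto fmha_fwd_launch_pieces_example(fmha_fwd_args args)
{
    auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<FmhaKernel>(args);
    dim3 blocks         = FmhaKernel::BlockSize();
    // kargs / grids / blocks are then handed to the ck_tile kernel-launch utilities
    return ck_tile::make_tuple(kargs, grids, blocks);
}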

template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(Kernel::kIsGroupMode)
        {
            return Kernel::MakeKargs(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_acc_ptr,
                args.o_acc_ptr, args.batch, args.seqstart_q_ptr, args.seqstart_k_ptr,
                args.seqlen_k_ptr, args.hdim_q, args.hdim_v, args.nhead_q,
                args.nhead_q / args.nhead_k, args.num_splits, args.block_table_ptr,
                args.batch_stride_block_table, args.page_block_size, args.is_gappy,
                args.scale_s, args.scale_p, args.stride_q, args.stride_k, args.stride_v,
                args.stride_bias, args.stride_o_acc, args.nhead_stride_q, args.nhead_stride_k,
                args.nhead_stride_v, args.nhead_stride_bias, args.nhead_stride_lse_acc,
                args.nhead_stride_o_acc,
                args.batch_stride_k, // only used for paged-kvcache
                args.batch_stride_v, // only used for paged-kvcache
                args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left,
                args.window_size_right, args.mask_type);
        }
        else
        { // create batch mode kernel arguments
            return Kernel::MakeKargs(
                args.q_ptr, args.k_ptr, args.v_ptr, args.bias_ptr, args.lse_acc_ptr,
                args.o_acc_ptr, args.batch, args.seqlen_q, args.seqlen_k, args.seqlen_k_ptr,
                args.hdim_q, args.hdim_v, args.nhead_q, args.nhead_q / args.nhead_k,
                args.num_splits, args.block_table_ptr, args.batch_stride_block_table,
                args.page_block_size, args.cache_batch_idx, args.scale_s, args.scale_p,
                args.stride_q, args.stride_k, args.stride_v, args.stride_bias, args.stride_o_acc,
                args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v,
                args.nhead_stride_bias, args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                args.batch_stride_q, args.batch_stride_k, args.batch_stride_v,
                args.batch_stride_bias, args.batch_stride_lse_acc, args.batch_stride_o_acc,
                args.split_stride_lse_acc, args.split_stride_o_acc, args.window_size_left,
                args.window_size_right, args.mask_type);
        }
    }();

    dim3 grids = Kernel::GridSize(
        args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);

    return ck_tile::make_tuple(kargs, grids);
}

template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(Kernel::kIsGroupMode)
        {
            return Kernel::MakeKargs(
                args.lse_acc_ptr, args.o_acc_ptr, args.lse_ptr, args.o_ptr, args.batch,
                args.seqstart_q_ptr, args.hdim_v, args.num_splits, args.scale_o,
                args.stride_o_acc, args.stride_o, args.nhead_stride_lse_acc,
                args.nhead_stride_o_acc, args.nhead_stride_lse, args.nhead_stride_o,
                args.split_stride_lse_acc, args.split_stride_o_acc);
        }
        else
        { // create batch mode kernel arguments
            return Kernel::MakeKargs(
                args.lse_acc_ptr, args.o_acc_ptr, args.lse_ptr, args.o_ptr, args.batch,
                args.seqlen_q, args.hdim_v, args.num_splits, args.scale_o, args.stride_o_acc,
                args.stride_o, args.nhead_stride_lse_acc, args.nhead_stride_o_acc,
                args.nhead_stride_lse, args.nhead_stride_o, args.batch_stride_lse_acc,
                args.batch_stride_o_acc, args.batch_stride_lse, args.batch_stride_o,
                args.split_stride_lse_acc, args.split_stride_o_acc);
        }
    }();

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);

    return ck_tile::make_tuple(kargs, grids);
}

template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = Kernel::MakeKargs(
        args.q_ptr, args.k_ptr, args.knew_ptr, args.v_ptr, args.vnew_ptr, args.seqlen_q,
        args.seqlen_k_ptr, args.seqlen_knew, args.hdim_q, args.hdim_v, args.nhead_q,
        args.nhead_q / args.nhead_k, args.rotary_cos_ptr, args.rotary_sin_ptr, args.rotary_dim,
        args.has_mask, args.block_table_ptr, args.batch_stride_block_table, args.page_block_size,
        args.cache_batch_idx, args.stride_q, args.stride_k, args.stride_knew, args.stride_v,
        args.stride_vnew, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_knew,
        args.nhead_stride_v, args.nhead_stride_vnew, args.batch_stride_q, args.batch_stride_k,
        args.batch_stride_knew, args.batch_stride_v, args.batch_stride_vnew);

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);

    return ck_tile::make_tuple(kargs, grids);
}

// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_,
          ck_tile::index_t kM0_, ck_tile::index_t kN0_, ck_tile::index_t kK0_,
          ck_tile::index_t kN1_, ck_tile::index_t kK1_, ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_, ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          typename FmhaMask_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_, bool kHasDropout_, bool kDoFp8StaticQuant_,
          bool kPadS_, bool kPadSK_, bool kPadD_, bool kPadDv_>
struct fmha_fwd_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0 = kM0_, kN0 = kN0_, kK0 = kK0_, kN1 = kN1_, kK1 = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor         = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum           = FmhaPipelineEnum_;
    using FmhaMask                                   = ck_tile::remove_cvref_t<FmhaMask_>;
    static constexpr auto BiasEnum                   = BiasEnum_;
    static constexpr bool kStoreLse         = kStoreLse_;
    static constexpr bool kHasDropout       = kHasDropout_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS = kPadS_, kPadSK = kPadSK_, kPadD = kPadD_, kPadDv = kPadDv_;
};

template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);

template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_,
          ck_tile::index_t kM0_, ck_tile::index_t kN0_, ck_tile::index_t kK0_,
          ck_tile::index_t kN1_, ck_tile::index_t kK1_, ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_, ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          typename FmhaMask_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_, bool kDoFp8StaticQuant_, bool kIsPagedKV_,
          bool kPadS_, bool kPadSK_, bool kPadD_, bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0 = kM0_, kN0 = kN0_, kK0 = kK0_, kN1 = kN1_, kK1 = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor         = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum           = FmhaPipelineEnum_;
    using FmhaMask                                   = ck_tile::remove_cvref_t<FmhaMask_>;
    static constexpr auto BiasEnum                   = BiasEnum_;
    static constexpr bool kStoreLse         = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS = kPadS_, kPadSK = kPadSK_, kPadD = kPadD_, kPadDv = kPadDv_;
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};

template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();

template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_,
          ck_tile::index_t kN1_, bool kStoreLse_, bool kDoFp8StaticQuant_,
          bool kPadS_, bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
    static constexpr ck_tile::index_t HDim  = HDim_;
    using DataType                          = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode      = kIsGroupMode_;
    static constexpr ck_tile::index_t kN1   = kN1_;
    static constexpr bool kStoreLse         = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS             = kPadS_;
    static constexpr bool kPadDv            = kPadDv_;
};

template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();

// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_, typename DataType_,
          ck_tile::index_t kTileSizeS_, ck_tile::index_t kTileSizeSk_,
          ck_tile::index_t kTileSizeD_, ck_tile::index_t kTileSizeDv_,
          bool kIsVLayoutRowMajor_, bool kPadS_, bool kPadSk_, bool kPadD_, bool kPadDv_,
          ck_tile::RotaryEmbeddingEnum RotaryEnum_, bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    static constexpr ck_tile::index_t kTileSizeS  = kTileSizeS_;
    static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
    static constexpr ck_tile::index_t kTileSizeD  = kTileSizeD_;
    static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
    static constexpr bool kIsVLayoutRowMajor      = kIsVLayoutRowMajor_;
    static constexpr bool kPadS = kPadS_, kPadSk = kPadSk_, kPadD = kPadD_, kPadDv = kPadDv_;
    static constexpr auto RotaryEnum = RotaryEnum_;
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};

template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);

// This is the public API; the implementations below are generated by script
struct fmha_fwd_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    bool is_v_rowmajor;
    mask_enum mask_type;
    bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi. keep in sync with BlockAttentionBiasEnum
    bool has_lse;
    bool has_dropout;
    bool do_fp8_static_quant;
    // TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

struct fmha_fwd_splitkv_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    bool is_v_rowmajor;
    mask_enum mask_type;
    bias_enum bias_type; // 0: no bias, 1: elementwise bias, 2: alibi. keep in sync with BlockAttentionBiasEnum
    bool has_lse;
    bool do_fp8_static_quant;
    // TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
                       fmha_fwd_splitkv_args,
                       const ck_tile::stream_config&);

struct fmha_fwd_appendkv_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_v_rowmajor;
    rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
                        fmha_fwd_appendkv_args,
                        const ck_tile::stream_config&);
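
// --- Illustrative sketch (not part of the original header) ---------------------------------
// Dispatching through the public API. The ck_tile::stream_config construction and the
// hipStream_t parameter are assumptions for this example; the traits must describe a
// combination that was actually generated, otherwise the dispatcher has no kernel to run.
inline float run_fmha_fwd_example(fmha_fwd_traits traits, fmha_fwd_args args, hipStream_t stream)
{
    return fmha_fwd(traits, args, ck_tile::stream_config{stream});
}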

@@ -1,157 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ostream>
#include <string>

#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>

// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
    no_mask = 0,
    mask_top_left,
    mask_bottom_right,
    window_generic,
};

struct mask_info
{
    mask_enum type;
    ck_tile::index_t y, x;
    ck_tile::index_t left, right; // FA style SWA left/right

    void serialize(std::ostream& os) const
    {
        if(type == mask_enum::no_mask)
            os << "n";
        else if(type == mask_enum::mask_top_left)
            os << "t(" << left << ":" << right << ")";
        else if(type == mask_enum::mask_bottom_right)
            os << "b(" << left << ":" << right << ")";
        else
        {
            os << "g(" << y << ":" << x << ")";
        }
    }

    static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
    {
        ck_tile::index_t x_total = seqlen_k;
        ck_tile::index_t y_total = seqlen_q;
        mask_info tmp;
        auto found_0 = str.find(':');
        if(found_0 != std::string::npos)
        {
            std::string t = str.substr(0, found_0);
            std::string v = str.substr(found_0 + 1);
            if(t == "xt" || t == "xb")
            {
                // xformer style sliding window attn from top-left
                ck_tile::index_t window_size = atoi(v.c_str());
                ck_tile::index_t left_size   = -1;
                ck_tile::index_t right_size  = 0;
                if(window_size > 0)
                {
                    left_size  = window_size / 2;
                    right_size = window_size - 1 - left_size;
                }
                auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                    left_size, right_size, y_total, x_total, t == "xt");

                tmp.type  = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
                tmp.y     = r.at(ck_tile::number<0>{});
                tmp.x     = r.at(ck_tile::number<1>{});
                tmp.left  = left_size;
                tmp.right = right_size;
            }
            else
            {
                auto found_1 = v.find(",");
                if(found_1 == std::string::npos)
                {
                    printf("not supported value %s, %s\n", v.c_str(), str.c_str());
                    assert(0);
                }
                tmp.type            = mask_enum::window_generic;
                ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
                ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
                // TODO: some validation
                if(t == "t")
                {
                    tmp.type = mask_enum::mask_top_left;
                    auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                        v0, v1, y_total, x_total, true);
                    tmp.y     = r.at(ck_tile::number<0>{});
                    tmp.x     = r.at(ck_tile::number<1>{});
                    tmp.left  = v0;
                    tmp.right = v1;
                }
                else if(t == "b")
                {
                    tmp.type = mask_enum::mask_bottom_right;
                    auto r   = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
                        v0, v1, y_total, x_total, false);
                    tmp.y     = r.at(ck_tile::number<0>{});
                    tmp.x     = r.at(ck_tile::number<1>{});
                    tmp.left  = v0;
                    tmp.right = v1;
                }
                else if(t == "g")
                {
                    tmp.y     = v0;
                    tmp.x     = v1;
                    tmp.left  = v0; // TODO: don't use this?
                    tmp.right = v1;
                }
                else
                {
                    printf("not supported type %s, %s\n", t.c_str(), str.c_str());
                    assert(0);
                }
            }
        }
        else
        {
            auto set_causal_top_left = [&]() {
                tmp.type  = mask_enum::mask_top_left;
                tmp.y     = seqlen_q;
                tmp.x     = 1;
                tmp.left  = -1;
                tmp.right = 0;
            };
            auto set_causal_bottom_right = [&]() {
                tmp.type  = mask_enum::mask_bottom_right;
                tmp.y     = seqlen_q;
                tmp.x     = seqlen_k - seqlen_q + 1;
                tmp.left  = -1;
                tmp.right = 0;
            };
            if(str == "t")
                set_causal_top_left();
            else if(str == "b")
                set_causal_bottom_right();
            else
            {
                tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
                if(tmp.type == mask_enum::mask_top_left)
                {
                    set_causal_top_left();
                }
                else if(tmp.type == mask_enum::mask_bottom_right)
                {
                    set_causal_bottom_right();
                }
            }
        }
        return tmp;
    }

    friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
    {
        mi.serialize(os);
        return os;
    }
};
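
// --- Illustrative usage (not part of the original file) -------------------------------------
// Strings accepted by mask_info::decode, as implemented above:
//   "0"        -> no mask                     "t" / "1"  -> causal, top-left aligned
//   "b" / "2"  -> causal, bottom-right aligned "t:128,0" -> window, 128 left / 0 right, top-left
//   "g:64,64"  -> generic window given directly as y,x
inline mask_info make_causal_mask_example(ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
    return mask_info::decode("b", seqlen_q, seqlen_k); // bottom-right aligned causal mask
}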

@@ -22,6 +22,7 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
                           dtype,
                           false, // is_group_mode
                           true,  // is_v_rowmajor
                           false, // has_logits_soft_cap
                           mask.type,
                           enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
                           has_lse,
@@ -85,6 +86,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
    ck_tile::index_t stride_attn_bias = 0;
    ck_tile::index_t batch_stride_bias = 0;
    ck_tile::index_t nhead_stride_bias = 0;

    if (attn_bias_.has_value()) {
        auto a_b = attn_bias_.value();
        CHECK_DEVICE(a_b);
@@ -94,7 +96,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
        nhead_stride_bias = a_b.stride(1);
        batch_stride_bias = a_b.stride(0);
    }

    return fmha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
@@ -116,6 +117,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         softmax_scale, // scale_s
                         1,             // scale_p
                         1,             // scale_o
                         0.0f,          // logits_soft_cap
                         stride_q,
                         stride_k,
                         stride_v,
@@ -139,6 +141,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         -1, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
                         drop_seed_offset};

@@ -20,6 +20,7 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
                           dtype,
                           true, // is_group_mode
                           true, // is_v_rowmajor
                           false, // has_logits_soft_cap
                           mask.type,
                           enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
                           has_lse,
@@ -117,6 +118,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         softmax_scale, // scale_s
                         1,             // scale_p
                         1,             // scale_o
                         0.0f,          // logits_soft_cap
                         stride_q,
                         stride_k,
                         stride_v,
@@ -140,6 +142,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         -1, // min_seqlen_q
                         p_dropout,
                         has_dropout_randval,
                         drop_seed_offset};

@@ -1,84 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <ck_tile/core.hpp>
#include <ck_tile/host/host_tensor.hpp>

#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>

// keep in sync with RotaryEmbeddingEnum
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
                        ck_tile::index_t rotary_dim,
                        std::optional<unsigned> seed = std::nullopt)
{
    // return dummy tensors if we won't apply RoPE at all
    if(rotary_dim <= 0)
    {
        ck_tile::HostTensor<DataType> dummy({1, 1});
        return std::make_tuple(dummy, dummy);
    }

    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_real_distribution<float> generator(0.0f, 1.0f);

    const ck_tile::index_t num_rows = seqlen * 2;
    const ck_tile::index_t num_cols = rotary_dim / 2;

    using std::begin, std::end;

    ck_tile::HostTensor<float> angle({num_rows, num_cols});
    std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });

    ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::cos(origin_value));
    });

    ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::sin(origin_value));
    });

    return std::make_tuple(cos, sin);
}

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
                     const ck_tile::HostTensor<DataType>& sin,
                     ck_tile::index_t seqlen_offset,
                     ck_tile::index_t seqlen)
{
    assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
    assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));

    assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));

    const ck_tile::index_t num_rows = seqlen;
    const ck_tile::index_t num_cols = cos.get_length(1);

    ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
    cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });

    ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
    sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });

    return std::make_tuple(cos_pt, sin_pt);
}
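
// --- Illustrative usage (not part of the original file) -------------------------------------
// Building a rotary table for a whole cache and slicing out the rows consumed by one
// append-KV step. The DataType and the sizes below are assumptions for the example only.
inline void rotary_host_tensor_example()
{
    using DataType  = ck_tile::half_t;
    auto [cos, sin] = generate_rotary_cos_sin<DataType>(/*seqlen=*/4096, /*rotary_dim=*/64,
                                                        /*seed=*/42u);
    // rows [past_len, past_len + seqlen_knew) correspond to the newly appended tokens
    auto [cos_step, sin_step] =
        slice_rotary_cos_sin(cos, sin, /*seqlen_offset=*/1024, /*seqlen=*/16);
    (void)cos_step;
    (void)sin_step;
}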

@@ -5,6 +5,12 @@ import os
import sys


# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
    m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]


# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> set[str]:
@@ -17,6 +23,8 @@ def model_names(filename: str) -> set[str]:
            if len(line_parts) == 1:
                line_parts = line.split(",")
            model_name = line_parts[0]
            if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
                continue
            names.add(model_name)
    return names


@@ -9,6 +9,7 @@ import copy
import csv
import dataclasses
import functools
import gc
import importlib
import itertools
import json
@@ -2387,6 +2388,7 @@ class BenchmarkRunner:
        )

        def warmup(fn, model, example_inputs, mode, niters=10):
            gc.collect()
            peak_mem = 0
            start_stats = get_dynamo_stats()
            try:
@@ -2548,6 +2550,7 @@ class BenchmarkRunner:
            return experiment(*self.maybe_cast(model, example_inputs))

        def warmup(fn, model, example_inputs, mode, niters=5):
            gc.collect()
            peak_mem = 0
            start_stats = get_dynamo_stats()
            try:

@@ -106,6 +106,11 @@ finally:
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = {}

# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
    m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]


# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
@@ -116,6 +121,8 @@ with open(MODELS_FILENAME) as fh:
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(",")
        if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
            continue
        batch_size = int(batch_size)
        BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
    assert len(BATCH_SIZE_KNOWN_MODELS)

@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1



basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1009000000,0.1



@@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1



basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8787000000,0.1




@@ -39,13 +39,20 @@ finally:
    from timm.models import create_model

TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
    m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]

filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
with open(filename) as fh:
    lines = fh.readlines()
    lines = [line.rstrip() for line in lines]
    for line in lines:
        model_name, batch_size = line.split(" ")
        if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
            continue
        TIMM_MODELS[model_name] = int(batch_size)

@@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
    // check if the key is unrecognized.
    if (device_config_parser_hook_) {
      TORCH_CHECK(
          keys_.find(key) != keys_.end(),
          getKeys().find(key) != getKeys().end(),
          "Unrecognized key '",
          key,
          "' in Accelerator allocator config.");

@@ -220,11 +220,24 @@ class C10_API AcceleratorAllocatorConfig {
    return instance().last_allocator_settings_;
  }

  // Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
  // issue.
  static std::unordered_set<std::string>& getMutableKeys() {
    static std::unordered_set<std::string> keys{
        "max_split_size_mb",
        "max_non_split_rounding_mb",
        "garbage_collection_threshold",
        "roundup_power2_divisions",
        "expandable_segments",
        "pinned_use_background_threads"};
    return keys;
  }

  // Returns the set of valid keys for the allocator configuration.
  // This set is used to validate the presence and correctness of keys in
  // device-specific configuration parsers.
  static const std::unordered_set<std::string>& getKeys() {
    return keys_;
    return getMutableKeys();
  }

  // Registers a device-specific configuration parser hook and its key. This
@@ -238,9 +251,10 @@ class C10_API AcceleratorAllocatorConfig {
      std::function<void(const std::string&)>&& hook,
      const std::unordered_set<std::string>& keys) {
    device_config_parser_hook_ = std::move(hook);
    auto& mutable_keys = getMutableKeys();
    for (auto& key : keys) {
      TORCH_CHECK(
          keys_.insert(key).second,
          mutable_keys.insert(key).second,
          "Duplicated key '",
          key,
          "' found in device-specific configuration parser hook registration");

@@ -326,17 +340,6 @@ class C10_API AcceleratorAllocatorConfig {
  // their own environment configuration extensions.
  inline static std::function<void(const std::string&)>
      device_config_parser_hook_{nullptr};

  // A set of valid configuration keys, including both common and
  // device-specific options. This set is used to validate the presence and
  // correctness of keys during parsing.
  inline static std::unordered_set<std::string> keys_{
      "max_split_size_mb",
      "max_non_split_rounding_mb",
      "garbage_collection_threshold",
      "roundup_power2_divisions",
      "expandable_segments",
      "pinned_use_background_threads"};
};

C10_API inline void setAllocatorSettings(const std::string& env) {

c10/core/CachingDeviceAllocator.cpp (new file)
@@ -0,0 +1,10 @@
#include <c10/core/CachingDeviceAllocator.h>

namespace c10 {

// Ensures proper DLL export of this pure virtual base class on Windows,
// since it's mainly used in other DLLs outside c10.dll.
DeviceAllocator::DeviceAllocator() = default;
DeviceAllocator::~DeviceAllocator() = default;

} // namespace c10

@@ -1,6 +1,7 @@
#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/Stream.h>

namespace c10::CachingDeviceAllocator {

@@ -59,3 +60,55 @@ struct DeviceStats {
};

} // namespace c10::CachingDeviceAllocator

namespace c10 {

using CaptureId_t = unsigned long long;

// first is set if the instance is created by Graph mode capture_begin.
// second is set if the instance is created by Graph mode graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;

struct C10_API DeviceAllocator : public c10::Allocator {
  DeviceAllocator();
  ~DeviceAllocator() override;

  // Returns true if the allocator has been properly initialized and is ready
  // for use
  virtual bool initialized() = 0;

  // Releases all cached device memory from the specified memory pool back to
  // the system
  virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;

  // Associates a memory allocation with a stream to establish dependency
  // tracking. Prevents memory reuse until all operations on the specified
  // stream complete
  virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0;

  // Retrieves comprehensive memory statistics for the specified device,
  // including allocation patterns, usage metrics
  virtual CachingDeviceAllocator::DeviceStats getDeviceStats(
      c10::DeviceIndex device) = 0;

  // Resets cumulative allocation statistics for the specified device to zero
  virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;

  // Resets peak memory usage statistics for the specified device
  virtual void resetPeakStats(c10::DeviceIndex device) = 0;
};

// This function is used to get the DeviceAllocator for a specific device type
// and keep backward compatibility with c10::GetAllocator.
C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) {
  TORCH_CHECK(
      t != DeviceType::CPU,
      "getDeviceAllocator is not supported for CPU device type.");
  auto* allocator = c10::GetAllocator(t);
  auto* device_allocator = dynamic_cast<DeviceAllocator*>(allocator);
  TORCH_INTERNAL_ASSERT(
      device_allocator, "Allocator for ", t, " is not a DeviceAllocator.");
  return device_allocator;
}

} // namespace c10
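
// --- Illustrative sketch (not part of the diff) ------------------------------------------------
// Querying the new device-agnostic interface. Assumes a build where the registered CUDA
// allocator derives from c10::DeviceAllocator, as the change above requires.
//
//   void reset_and_snapshot_stats_example(c10::DeviceIndex device) {
//     auto* allocator = c10::getDeviceAllocator(c10::DeviceType::CUDA);
//     allocator->resetPeakStats(device);
//     auto stats = allocator->getDeviceStats(device); // DeviceStats snapshot for `device`
//     (void)stats;
//   }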

@@ -191,11 +191,17 @@ class C10_API Scalar {
  isIntegral() const {
    return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
  }

  bool isIntegral(bool includeBool) const {
    return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
        (includeBool && isBoolean());
  }

  // See Note [Meaning of HAS_u]
  bool isUnsigned() const {
    return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0);
  }

  bool isComplex() const {
    return Tag::HAS_z == tag;
  }
@@ -19,25 +19,16 @@

#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>

#include <torch/headeronly/core/ScalarType.h>

namespace c10 {

// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};

// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};

// For the macros below:
// [dtype Macros note] For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information, you probably want one of the
@@ -57,56 +48,6 @@ struct dummy_int1_7_t {};
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.

// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
  _(uint8_t, Byte) /* 0 */ \
  _(int8_t, Char) /* 1 */ \
  _(int16_t, Short) /* 2 */ \
  _(int, Int) /* 3 */ \
  _(int64_t, Long) /* 4 */ \
  _(at::Half, Half) /* 5 */ \
  _(float, Float) /* 6 */ \
  _(double, Double) /* 7 */ \
  _(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
  _(c10::complex<float>, ComplexFloat) /* 9 */ \
  _(c10::complex<double>, ComplexDouble) /* 10 */ \
  _(bool, Bool) /* 11 */ \
  _(c10::qint8, QInt8) /* 12 */ \
  _(c10::quint8, QUInt8) /* 13 */ \
  _(c10::qint32, QInt32) /* 14 */ \
  _(at::BFloat16, BFloat16) /* 15 */ \
  _(c10::quint4x2, QUInt4x2) /* 16 */ \
  _(c10::quint2x4, QUInt2x4) /* 17 */ \
  _(c10::bits1x8, Bits1x8) /* 18 */ \
  _(c10::bits2x4, Bits2x4) /* 19 */ \
  _(c10::bits4x2, Bits4x2) /* 20 */ \
  _(c10::bits8, Bits8) /* 21 */ \
  _(c10::bits16, Bits16) /* 22 */ \
  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
  _(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
  _(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
  _(uint16_t, UInt16) /* 27 */ \
  _(uint32_t, UInt32) /* 28 */ \
  _(uint64_t, UInt64) /* 29 */ \
  _(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
  _(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
  _(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
  _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
  _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
  _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
  _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
  _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
  _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
  _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
  _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
  _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
  _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
  _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
  _(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
  _(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */

// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
@@ -152,17 +93,6 @@ struct dummy_int1_7_t {};
    _(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
    _(at::Float8_e8m0fnu, Float8_e8m0fnu)

enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ENUM_ST_ENUM_VAL_
  Undefined,
  NumOptions
};

constexpr uint16_t NumScalarTypes =
    static_cast<uint16_t>(ScalarType::NumOptions);

namespace impl {

// These are used to map ScalarTypes to C++ types.

@@ -110,8 +110,22 @@ class C10_CUDA_API CUDAAllocatorConfig {
    return instance().m_use_async_allocator;
  }

  // Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
  // issue.
  static const std::unordered_set<std::string>& getKeys() {
    return keys_;
    static std::unordered_set<std::string> keys{
        "backend",
        // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
        // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
        "release_lock_on_cud"
        "amalloc",
        "pinned_use_cud"
        "a_host_register",
        // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
        "release_lock_on_hipmalloc",
        "pinned_use_hip_host_register",
        "pinned_num_register_threads"};
    return keys;
  }

  static CUDAAllocatorConfig& instance() {
@@ -163,18 +177,6 @@ class C10_CUDA_API CUDAAllocatorConfig {
  std::atomic<bool> m_pinned_use_cuda_host_register{false};
  std::atomic<bool> m_use_async_allocator{false};
  std::atomic<bool> m_is_allocator_loaded{false};
  inline static std::unordered_set<std::string> keys_{
      "backend",
      // keep BC for Rocm: `cuda` -> `cud` `a`, to avoid hipify issues
      // NOLINTBEGIN(bugprone-suspicious-missing-comma,-warnings-as-errors)
      "release_lock_on_cud"
      "amalloc",
      "pinned_use_cud"
      "a_host_register",
      // NOLINTEND(bugprone-suspicious-missing-comma,-warnings-as-errors)
      "release_lock_on_hipmalloc",
      "pinned_use_hip_host_register",
      "pinned_num_register_threads"};
};

// Keep this for backwards compatibility