Compare commits


1 commit

f85d332b24 [skip ci] Threading through cute based FA for flex
ghstack-source-id: 87db2561c6d69c1022f9fe249857426a47c0457a
Pull-Request: https://github.com/pytorch/pytorch/pull/159521
2025-07-31 17:41:24 -07:00
370 changed files with 8338 additions and 12615 deletions

View File

@ -144,6 +144,16 @@ case "$tag" in
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9)
CUDA_VERSION=12.6.3
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.12
@ -154,6 +164,39 @@ case "$tag" in
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.10
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.12
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks)
CUDA_VERSION=12.6
ANACONDA_PYTHON_VERSION=3.13
GCC_VERSION=9
VISION=yes
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
TRITON=yes
INDUCTOR_BENCHMARKS=yes
;;
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9)
CUDA_VERSION=12.8.1
ANACONDA_PYTHON_VERSION=3.10
@ -176,6 +219,18 @@ case "$tag" in
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.11-clang12)
ANACONDA_PYTHON_VERSION=3.11
CLANG_VERSION=12
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-py3.9-gcc9)
ANACONDA_PYTHON_VERSION=3.9
GCC_VERSION=9
VISION=yes
TRITON=yes
;;
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-noble-rocm-n-py3)
if [[ $tag =~ "jammy" ]]; then
ANACONDA_PYTHON_VERSION=3.10

View File

@ -1 +1 @@
f7888497a1eb9e98d4c07537f0d0bcfe180d1363
11ec6354315768a85da41032535e3b7b99c5f706

View File

@ -68,8 +68,8 @@ function install_nvshmem {
# download, unpack, install
wget -q "${url}"
tar xf "${filename}.tar.gz"
cp -a "libnvshmem/include/"* /usr/local/cuda/include/
cp -a "libnvshmem/lib/"* /usr/local/cuda/lib64/
cp -a "libnvshmem/include/"* /usr/local/include/
cp -a "libnvshmem/lib/"* /usr/local/lib/
# cleanup
cd ..
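Purely as an illustration and not part of the diff: a quick sanity check one could run after the copy step above, given the new /usr/local prefix it uses. The ldconfig call is an assumption about the surrounding image setup; it refreshes the dynamic linker cache so libraries under /usr/local/lib are found.

# Hypothetical post-install check for the relocated NVSHMEM headers and libraries
ls /usr/local/include | grep -i nvshmem
ls /usr/local/lib | grep -i nvshmem
ldconfig  # refresh the linker cache so /usr/local/lib is picked up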

View File

@ -15,37 +15,11 @@ function install_timm() {
commit=$(get_pinned_commit timm)
pip_install "git+https://github.com/huggingface/pytorch-image-models@${commit}"
}
function install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
python install.py --continue_on_fail
# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
chown -R jenkins torchbench
# Clean up
conda_run pip uninstall -y torch torchvision triton
}
# Pango is needed for weasyprint which is needed for doctr
conda_install pango
# Stable packages are ok here, just to satisfy TorchBench check
pip_install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
install_torchbench
install_huggingface
install_timm
# Clean up
conda_run pip uninstall -y torch torchvision torchaudio triton

View File

@ -103,5 +103,5 @@ fi
# It depends on torch and triton. We don't want to install
# triton and torch from production on Docker CI images
if [[ "$ANACONDA_PYTHON_VERSION" != 3.9* ]]; then
pip_install helion --no-deps
pip_install helion==0.0.10 --no-deps
fi

View File

@ -361,6 +361,7 @@ pwlf==2.2.1
#Pinned versions: 2.2.1
#test that import: test_sac_estimator.py
# To build PyTorch itself
pyyaml
pyzstd

View File

@ -1,7 +1,7 @@
sphinx==5.3.0
#Description: This is used to generate PyTorch docs
#Pinned versions: 5.3.0
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@722b7e6f9ca512fcc526ad07d62b3d28c50bb6cd#egg=pytorch_sphinx_theme2
-e git+https://github.com/pytorch/pytorch_sphinx_theme.git@pytorch_sphinx_theme2#egg=pytorch_sphinx_theme2
# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering
# but it doesn't seem to work and hangs around idly. The initial thought that it is probably
@ -50,7 +50,7 @@ IPython==8.12.0
#Pinned versions: 8.12.0
myst-nb==0.17.2
#Description: This is used to generate PyTorch functorch and torch.compile docs.
#Description: This is used to generate PyTorch functorch and torch.compile docs
#Pinned versions: 0.17.2
# The following are required to build torch.distributed.elastic.rendezvous.etcd* docs

View File

@ -98,9 +98,8 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
# (optional) Install non-default Ninja version
ARG NINJA_VERSION

View File

@ -98,9 +98,8 @@ COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/huggingface.txt huggingface.txt
COPY ci_commit_pins/timm.txt timm.txt
COPY ci_commit_pins/torchbench.txt torchbench.txt
RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt torchbench.txt
RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface.txt
ARG TRITON
ARG TRITON_CPU

View File

@ -194,7 +194,7 @@ ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library
ROCBLAS_ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH)
ROCBLAS_OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $ROCBLAS_OTHER_FILES)
ROCBLAS_LIB_FILES=($ROCBLAS_ARCH_SPECIFIC_FILES $OTHER_FILES)
# hipblaslt library files
HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library

View File

@ -229,6 +229,7 @@ function install_torchrec_and_fbgemm() {
pip_install tabulate # needed for newer fbgemm
pip_install patchelf # needed for rocm fbgemm
pushd /tmp
local wheel_dir=dist/fbgemm_gpu
local found_whl=0
@ -244,7 +245,7 @@ function install_torchrec_and_fbgemm() {
if [ "${found_whl}" == "0" ]; then
git clone --recursive https://github.com/pytorch/fbgemm
pushd fbgemm/fbgemm_gpu
git checkout "${fbgemm_commit}" --recurse-submodules
git checkout "${fbgemm_commit}"
python setup.py bdist_wheel \
--build-variant=rocm \
-DHIP_ROOT_DIR="${ROCM_PATH}" \
@ -263,6 +264,7 @@ function install_torchrec_and_fbgemm() {
done
rm -rf fbgemm
popd
else
pip_build_and_install "git+https://github.com/pytorch/torchrec.git@${torchrec_commit}" dist/torchrec
pip_build_and_install "git+https://github.com/pytorch/FBGEMM.git@${fbgemm_commit}#subdirectory=fbgemm_gpu" dist/fbgemm_gpu
@ -281,6 +283,30 @@ function clone_pytorch_xla() {
fi
}
function checkout_install_torchbench() {
local commit
commit=$(get_pinned_commit torchbench)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi
# TODO (huydhn): transformers-4.44.2 added by https://github.com/pytorch/benchmark/pull/2488
# is regressing speedup metric. This needs to be investigated further
pip install transformers==4.38.1
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}
function install_torchao() {
local commit
commit=$(get_pinned_commit torchao)
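As an illustrative sketch only, not part of the diff: how the checkout_install_torchbench helper above is invoked by its callers. The model names are borrowed from the cachebench invocation later in this compare, and the torchbenchmark package name is an assumption.

# Hypothetical smoke-test flow using the helper defined above
checkout_install_torchbench nanogpt BERT_pytorch
# Callers then put the fresh checkout on PYTHONPATH, mirroring the test.sh changes below
PYTHONPATH=$(pwd)/torchbench python -c "import torchbenchmark"  # package name assumed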

View File

@ -157,29 +157,6 @@ test_jit_hooks() {
assert_git_not_dirty
}
# Shellcheck doesn't like it when you pass no arguments to a function
# that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2120
checkout_install_torchbench() {
local commit
commit=$(cat .ci/docker/ci_commit_pins/torchbench.txt)
git clone https://github.com/pytorch/benchmark torchbench
pushd torchbench
git checkout "$commit"
if [ "$1" ]; then
python install.py --continue_on_fail models "$@"
else
# Occasionally the installation may fail on one model but it is ok to continue
# to install and test other models
python install.py --continue_on_fail
fi
echo "Print all dependencies after TorchBench is installed"
python -mpip freeze
popd
}
torchbench_setup_macos() {
git clone --recursive https://github.com/pytorch/vision torchvision
git clone --recursive https://github.com/pytorch/audio torchaudio
@ -202,6 +179,8 @@ torchbench_setup_macos() {
USE_OPENMP=0 python setup.py develop
popd
# Shellcheck doesn't like it when you pass no arguments to a function that can take args. See https://www.shellcheck.net/wiki/SC2120
# shellcheck disable=SC2119,SC2120
checkout_install_torchbench
}

View File

@ -627,8 +627,6 @@ test_perf_for_dashboard() {
device=cuda_a10g
elif [[ "${TEST_CONFIG}" == *h100* ]]; then
device=cuda_h100
elif [[ "${TEST_CONFIG}" == *b200* ]]; then
device=cuda_b200
elif [[ "${TEST_CONFIG}" == *rocm* ]]; then
device=rocm
fi
@ -803,16 +801,6 @@ test_dynamo_benchmark() {
if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then
test_single_dynamo_benchmark "training" "$suite" "$shard_id" --training --amp "$@"
elif [[ "${TEST_CONFIG}" == *perf* ]]; then
# TODO (huydhn): Just smoke test some sample models
if [[ "${TEST_CONFIG}" == *b200* ]]; then
if [[ "${suite}" == "huggingface" ]]; then
export TORCHBENCH_ONLY_MODELS="DistillGPT2"
elif [[ "${suite}" == "timm_models" ]]; then
export TORCHBENCH_ONLY_MODELS="inception_v3"
elif [[ "${suite}" == "torchbench" ]]; then
export TORCHBENCH_ONLY_MODELS="hf_Bert"
fi
fi
test_single_dynamo_benchmark "dashboard" "$suite" "$shard_id" "$@"
else
if [[ "${TEST_CONFIG}" == *cpu* ]]; then
@ -1684,11 +1672,13 @@ elif [[ "${TEST_CONFIG}" == *timm* ]]; then
elif [[ "${TEST_CONFIG}" == cachebench ]]; then
install_torchaudio
install_torchvision
PYTHONPATH=/torchbench test_cachebench
checkout_install_torchbench nanogpt BERT_pytorch resnet50 hf_T5 llama moco
PYTHONPATH=$(pwd)/torchbench test_cachebench
elif [[ "${TEST_CONFIG}" == verify_cachebench ]]; then
install_torchaudio
install_torchvision
PYTHONPATH=/torchbench test_verify_cachebench
checkout_install_torchbench nanogpt
PYTHONPATH=$(pwd)/torchbench test_verify_cachebench
elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
install_torchaudio
install_torchvision
@ -1697,22 +1687,28 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then
# https://github.com/opencv/opencv-python/issues/885
pip_install opencv-python==4.8.0.74
if [[ "${TEST_CONFIG}" == *inductor_torchbench_smoketest_perf* ]]; then
PYTHONPATH=/torchbench test_inductor_torchbench_smoketest_perf
checkout_install_torchbench hf_Bert hf_Albert timm_vision_transformer
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_smoketest_perf
elif [[ "${TEST_CONFIG}" == *inductor_torchbench_cpu_smoketest_perf* ]]; then
PYTHONPATH=/torchbench test_inductor_torchbench_cpu_smoketest_perf
checkout_install_torchbench timm_vision_transformer phlippe_densenet basic_gnn_edgecnn \
llama_v2_7b_16h resnet50 timm_efficientnet mobilenet_v3_large timm_resnest \
functorch_maml_omniglot yolov3 mobilenet_v2 resnext50_32x4d densenet121 mnasnet1_0
PYTHONPATH=$(pwd)/torchbench test_inductor_torchbench_cpu_smoketest_perf
elif [[ "${TEST_CONFIG}" == *torchbench_gcp_smoketest* ]]; then
TORCHBENCHPATH=/torchbench test_torchbench_gcp_smoketest
checkout_install_torchbench
TORCHBENCHPATH=$(pwd)/torchbench test_torchbench_gcp_smoketest
else
checkout_install_torchbench
# Do this after checkout_install_torchbench to ensure we clobber any
# nightlies that torchbench may pull in
if [[ "${TEST_CONFIG}" != *cpu* ]]; then
install_torchrec_and_fbgemm
fi
PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id"
PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id"
fi
elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then
install_torchvision
PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER"
if [[ "$SHARD_NUMBER" -eq "1" ]]; then
test_inductor_aoti
fi

View File

@ -1 +1 @@
6fbc710b617f79b992ef2ebc7f95e818aa390293
bf305f538005f2e900f8850ed57146024a8bc559

View File

@ -1 +1 @@
6a39ba85fe0f2fff9494b5eccea717c93510c230
ca9e2be3ed6320b51f52f536595cd24e254f8bb2

View File

@ -1 +1 @@
b6a5b82b9948b610fa4c304d0d869c82b8f17db1
29ae4c76c026185f417a25e841d2cd5e65f087a3

View File

@ -488,10 +488,6 @@
- torch/_dynamo/**
- torch/csrc/dynamo/**
- test/dynamo/**
- test/dynamo_expected_failures/**
- test/dynamo_skips/**
- test/inductor_expected_failures/**
- test/inductor_skips/**
approved_by:
- guilhermeleobas
mandatory_checks_name:

View File

@ -193,7 +193,7 @@ LIBTORCH_CONTAINER_IMAGES: dict[str, str] = {
"cpu": "libtorch-cxx11-builder:cpu",
}
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"]
FULL_PYTHON_VERSIONS = ["3.9", "3.10", "3.11", "3.12", "3.13", "3.13t"]
def translate_desired_cuda(gpu_arch_type: str, gpu_arch_version: str) -> str:
@ -315,11 +315,6 @@ def generate_wheels_matrix(
# TODO: Enable python 3.13t on cpu-s390x
if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
continue
# TODO: Enable python 3.14 on non linux OSes
if os != "linux" and (
python_version == "3.14" or python_version == "3.14t"
):
continue
if use_split_build and (
arch_version not in ["12.6", "12.8", "12.9", "cpu"] or os != "linux"

View File

@ -96,7 +96,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ !contains(matrix.runner, 'b200') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ matrix.runner != 'B200' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |
@ -109,7 +109,7 @@ jobs:
no-sudo: true
- name: Setup Python
if: contains(matrix.runner, 'b200')
if: matrix.runner == 'B200'
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: '3.12'
@ -117,7 +117,7 @@ jobs:
- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200')
if: inputs.build-environment != 'linux-s390x-binary-manywheel' && matrix.runner != 'B200'
- name: configure aws credentials
if: ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
@ -128,7 +128,7 @@ jobs:
aws-region: us-east-1
- name: Login to Amazon ECR
if: ${{ inputs.aws-role-to-assume != '' && contains(matrix.runner, 'b200') }}
if: ${{ inputs.aws-role-to-assume != '' && matrix.runner == 'B200' }}
id: login-ecr
continue-on-error: true
uses: aws-actions/amazon-ecr-login@062b18b96a7aff071d4dc91bc00c4c1a7945b076 # v2.0.1
@ -166,17 +166,17 @@ jobs:
uses: pytorch/test-infra/.github/actions/setup-nvidia@main
with:
driver-version: ${{ matrix.config == 'legacy_nvidia_driver' && '525.105.17' || '570.133.07' }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && !contains(matrix.runner, 'b200') }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' && matrix.runner != 'B200' }}
- name: Setup GPU_FLAG for docker run
id: setup-gpu-flag
run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}"
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || contains(matrix.runner, 'b200')) }}
if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && (steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' || matrix.runner == 'B200') }}
- name: Setup SCCACHE_SERVER_PORT environment for docker run when on container
id: setup-sscache-port-flag
run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}"
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }}
if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && matrix.runner != 'B200' }}
- name: Lock NVIDIA A100 40GB Frequency
run: |
@ -277,8 +277,8 @@ jobs:
NO_TD: ${{ steps.keep-going.outputs.ci-no-td }}
TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }}
# Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs
SCCACHE_BUCKET: ${{ !contains(matrix.runner, 'b200') && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ !contains(matrix.runner, 'b200') && 'us-east-1' || '' }}
SCCACHE_BUCKET: ${{ matrix.runner != 'B200' && 'ossci-compiler-cache-circleci-v2' || '' }}
SCCACHE_REGION: ${{ matrix.runner != 'B200' && 'us-east-1' || '' }}
SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }}
DOCKER_IMAGE: ${{ inputs.docker-image }}
XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }}
@ -403,7 +403,7 @@ jobs:
job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }}
- name: Authenticate with AWS
if: ${{ contains(matrix.runner, 'b200') }}
if: ${{ matrix.runner == 'B200' }}
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results

View File

@ -34,8 +34,7 @@ jobs:
contents: read
pull-requests: write
name: Check labels
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
if: github.repository_owner == 'pytorch'
runs-on: linux.24_04.4x
steps:
- name: Checkout PyTorch

View File

@ -7,8 +7,7 @@ on:
jobs:
ghstack-mergeability-check:
# Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
if: github.repository_owner == 'pytorch' && false
if: github.repository_owner == 'pytorch'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

View File

@ -51,12 +51,17 @@ jobs:
docker-image-name: [
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc11-vllm,
pytorch-linux-jammy-cuda12.6-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.6-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.12-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.13-gcc9-inductor-benchmarks,
pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
pytorch-linux-jammy-py3.9-clang12,
pytorch-linux-jammy-py3.11-clang12,
pytorch-linux-jammy-py3.12-clang12,
pytorch-linux-jammy-py3.13-clang12,
pytorch-linux-jammy-rocm-n-py3,
pytorch-linux-noble-rocm-n-py3,
@ -71,8 +76,7 @@ jobs:
pytorch-linux-jammy-py3-clang12-onnx,
pytorch-linux-jammy-linter,
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
# Executorch pin needs update
# pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3-clang12-executorch,
pytorch-linux-jammy-py3.12-triton-cpu
]
include:

File diff suppressed because it is too large.

View File

@ -1,154 +0,0 @@
name: inductor-perf-b200
on:
schedule:
- cron: 0 7 * * 1-6
- cron: 0 7 * * 0
# NB: GitHub has an upper limit of 10 inputs here, so before we can sort it
# out, let try to run torchao cudagraphs_low_precision as part of cudagraphs
workflow_dispatch:
inputs:
training:
description: Run training (on by default)?
required: false
type: boolean
default: true
inference:
description: Run inference (on by default)?
required: false
type: boolean
default: true
default:
description: Run inductor_default?
required: false
type: boolean
default: false
dynamic:
description: Run inductor_dynamic_shapes?
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
type: boolean
default: true
freezing_cudagraphs:
description: Run inductor_cudagraphs with freezing for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
type: boolean
default: false
maxautotune:
description: Run inductor_max_autotune?
required: false
type: boolean
default: false
benchmark_configs:
description: The list of configs used the benchmark
required: false
type: string
default: inductor_huggingface_perf_cuda_b200,inductor_timm_perf_cuda_b200,inductor_torchbench_perf_cuda_b200
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf
build:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
# Use a bigger runner here because CUDA_ARCH 9.0 is only built for H100
# or newer GPUs, so it doesn't benefit much from existing compiler cache
# from trunk. Also use a memory-intensive runner here because memory is
# usually the bottleneck
runner: linux.12xlarge.memory
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9-inductor-benchmarks
cuda-arch-list: '10.0'
test-matrix: |
{ include: [
{ config: "inductor_huggingface_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_timm_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
{ config: "inductor_torchbench_perf_cuda_b200", shard: 1, num_shards: 1, runner: "linux.dgx.b200" },
]}
selected-test-configs: ${{ inputs.benchmark_configs }}
build-additional-packages: "vision audio fbgemm torchao"
secrets: inherit
test-periodically:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
test-weekly:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
timeout-minutes: 1440
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit
test:
name: cuda12.8-py3.10-gcc9-sm100
uses: ./.github/workflows/_linux-test.yml
needs: build
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc9-sm100
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.build.outputs.docker-image }}
test-matrix: ${{ needs.build.outputs.test-matrix }}
aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only
timeout-minutes: 720
disable-monitor: false
monitor-log-interval: 15
monitor-data-collect-interval: 4
secrets: inherit

View File

@ -81,21 +81,21 @@ jobs:
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
{ config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
{ config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.mi300.2" },
]}
secrets: inherit

View File

@ -75,11 +75,10 @@ jobs:
repo-owner: pytorch
branch: main
pin-folder: .github/ci_commit_pins
# executorch jobs are disabled since it needs some manual work for the hash update
# - repo-name: executorch
# repo-owner: pytorch
# branch: main
# pin-folder: .ci/docker/ci_commit_pins
- repo-name: executorch
repo-owner: pytorch
branch: main
pin-folder: .ci/docker/ci_commit_pins
- repo-name: triton
repo-owner: triton-lang
branch: main

View File

@ -51,6 +51,37 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-jammy-cuda12_4-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_4-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.4-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_4-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.4-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_4-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-cuda12_4-py3_10-gcc11-build:
name: linux-jammy-cuda12.4-py3.10-gcc11
uses: ./.github/workflows/_linux-build.yml

View File

@ -292,14 +292,13 @@ jobs:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" },
]}
secrets: inherit
@ -403,8 +402,38 @@ jobs:
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-sm89-build:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
cuda-arch-list: 8.9
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
{ config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu" },
]}
secrets: inherit
linux-jammy-cuda12_8-py3_10-gcc11-sm89-test:
name: linux-jammy-cuda12.8-py3.10-gcc11-sm89
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-jammy-cuda12_8-py3_10-gcc11-sm89-build
- target-determination
with:
build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm89
docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-sm89-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3-clang12-executorch-build:
if: false # Docker build needs pin update
name: linux-jammy-py3-clang12-executorch
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type

View File

@ -10,10 +10,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true
permissions:
id-token: write
contents: read
jobs:
get-default-label-prefix:
if: github.repository_owner == 'pytorch'

View File

@ -205,7 +205,7 @@ jobs:
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-jammy-py3.9-gcc11
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11-inductor-benchmarks
docker-image-name: ci-image:pytorch-linux-jammy-py3.9-gcc11
test-matrix: |
{ include: [
{ config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },

View File

@ -23,7 +23,7 @@ jobs:
with:
repository: pytorch/pytorch
stable-branch: viable/strict
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\", \"linux-aarch64\"]'
requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]'
secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }}
clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}

View File

@ -1,18 +1 @@
- This is the only AGENTS.md, there are no recursive AGENTS.md
- When you are working on a bug, first create a standalone file that
reproduces the bug and verify it fails in the expected way. Use this to
test if your changes work. Once the change is passing, find an appropriate
test file to add the test to and make sure to follow local conventions on
the test file.
- If you are running the real test suite, DO NOT run the entire test suite.
Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir'
- Do NOT run setup.py, you do not have a working build environment
- Do NOT run pre-commit, it is not setup
- To run lint, run 'lintrunner -a' (which will autoapply changes). lintrunner
ONLY accepts this flag, do not try to run on individual files.
- Do NOT attempt to install dependencies, you do not have Internet access
- When you are ready to make a PR, do exactly these steps:
- git stash -u
- git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch
- git stash pop
- Resolve conflicts if necessary
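For readability only: the commands that the AGENTS.md hunk above spells out, collected into a single sketch. Every path and command comes from the lines above; nothing new is introduced.

# Run a single test case rather than the whole suite
python test/test_torch.py TestTorch.test_dir
# Lint with auto-apply; lintrunner only accepts this flag, not individual files
lintrunner -a
# PR preparation steps, exactly as listed
git stash -u
git reset --hard "$(cat /tmp/orig_work.txt)"  # reset to the LOCAL branch, do NOT fetch
git stash pop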

View File

@ -14,6 +14,7 @@
/torch/csrc/autograd/ @albanD @soulitzer
/torch/autograd/ @albanD @soulitzer
/tools/autograd/ @albanD @soulitzer
/torch/header_only_apis.txt @janeyx99
/torch/nn/ @albanD @jbschlosser @mikaylagawarecki
/torch/optim/ @albanD @janeyx99
/test/test_public_bindings.py @albanD
@ -195,8 +196,3 @@ torch/backends/cudnn/ @eqy @syed-ahmed
/torch/utils/_cxx_pytree.py @XuehaiPan
/torch/utils/pytree/ @XuehaiPan
/torch/_dynamo/polyfills/pytree.py @XuehaiPan
# Relating to libtorch ABI
/torch/csrc/stable/ @janeyx99 @mikaylagawarecki
/torch/headeronly/ @janeyx99
/torch/header_only_apis.txt @janeyx99

View File

@ -276,7 +276,7 @@ conda install pkg-config libuv
pip install mkl-static mkl-include
# Add these packages if torch.distributed is needed.
# Distributed package support on Windows is a prototype feature and is subject to changes.
conda install -c conda-forge libuv
conda install -c conda-forge libuv=1.39
```
#### Install PyTorch

View File

@ -439,7 +439,6 @@ if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
@ -704,17 +703,21 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
list(APPEND AIR_BASIC ${TGT_BASIC})
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
endforeach()
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
add_custom_command(
COMMAND echo "// $$(date)" > metallib_dummy.cpp
DEPENDS kernels_basic.metallib
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
OUTPUT metallib_dummy.cpp
COMMENT "Updating metallibs timestamp")
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
else()
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
foreach(SHADER ${native_mps_metal})

View File

@ -1,6 +1,5 @@
#pragma once
#include <c10/core/CachingDeviceAllocator.h>
#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
@ -73,27 +72,6 @@ TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index);
// original device index that was active before the change.
TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index);
TORCH_API inline void emptyCache() {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->emptyCache();
}
TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
return at::getDeviceAllocator(device_type)->getDeviceStats(device_index);
}
TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index);
}
TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
const auto device_type = getAccelerator(true).value();
at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}
} // namespace at::accelerator
namespace at {

View File

@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAGraph.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/Functions.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAFunctions.h>
#include <cstddef>

View File

@ -2,7 +2,6 @@
#include <ATen/Tensor.h>
#include <c10/core/Device.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAGraphsC10Utils.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/flat_hash_map.h>

View File

@ -24,29 +24,6 @@ static void _assert_match(const O& original, const C& compared, const std::strin
}
}
template<>
void _assert_match<c10::Device, std::optional<c10::Device>>(
const c10::Device& original,
const std::optional<c10::Device>& compared,
const std::string& name) {
if (compared) {
const c10::Device& expected = compared.value();
if (original.type() != expected.type()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}
// If the expected device doesn't have an index (e.g., just "cuda"),
// or if both devices have the same index, consider them equal
if (expected.has_index() && original.has_index() && expected.index() != original.index()) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch! Expected: " << expected << ", Got: " << original;
throw std::runtime_error(msg.str());
}
}
}
void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional<c10::ScalarType> dtype, std::optional<c10::Device> device, std::optional<c10::Layout> layout) {
_assert_match(tensor.sym_sizes(), sizes, "sizes");
_assert_match(tensor.sym_strides(), strides, "strides");

View File

@ -367,27 +367,27 @@ void int8pack_mm_kernel_(
auto* C_data = C.data_ptr<T>();
const auto* S_data = scales.const_data_ptr<T>();
int64_t M = A.size(0);
int64_t N = B.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 4;
int M = A.size(0);
int N = B.size(0);
int K = A.size(1);
int lda = A.stride(0);
constexpr int BLOCK_M = 4;
constexpr int BLOCK_N = 4;
const int64_t MB = (M + BLOCK_M - 1) / BLOCK_M;
const int64_t NB = (N + BLOCK_N - 1) / BLOCK_N;
const int MB = (M + BLOCK_M - 1) / BLOCK_M;
const int NB = (N + BLOCK_N - 1) / BLOCK_N;
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
at::parallel_for(0, MB * NB, 0, [&](int begin, int end) {
int mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
for (const auto i : c10::irange(begin, end)) {
(void)i;
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(BLOCK_M, M - mb_start);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(BLOCK_N, N - nb_start);
const auto* A_ptr = A_data + mb_start * lda;
const auto* B_ptr = B_data + nb_start * K;

View File

@ -526,7 +526,7 @@ namespace {
// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enough to go beyond int32,
// TODO: to really support input tensor large enought to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input_.numel() < std::numeric_limits<int32_t>::max());
@ -681,7 +681,7 @@ namespace {
const dim3 grid(grid_x, grid_y, grid_z);
// we are dealing with packed tensor here. max index is the same as numel.
// TODO: to really support input tensor large enough to go beyond int32,
// TODO: to really support input tensor large enought to go beyond int32,
// we will need to restrict out shared memory usage and adjust the launch
// config;
AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());

View File

@ -1634,9 +1634,6 @@ bool use_fast_accum) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
TORCH_CHECK(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
@ -1719,9 +1716,6 @@ std::optional<c10::ScalarType> out_dtype) {
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
}
// check that the strides are valid, the fn will throw an error if not
check_valid_strides_and_return_transposed(mat_a);

View File

@ -223,7 +223,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
class CuFFTConfig {
public:
// Only move semantics is enough for this class. Although we already use
// Only move semantics is enought for this class. Although we already use
// unique_ptr for the plan, still remove copy constructor and assignment op so
// we don't accidentally copy and take perf hit.
CuFFTConfig(const CuFFTConfig&) = delete;

View File

@ -241,8 +241,6 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
Strides tensor_StrideA = make_strides(mat_a.strides());
Strides tensor_StrideB = make_strides(mat_b.strides());
Strides tensor_StrideOutput = make_strides(out.strides());
Strides tensor_ShapeA = make_strides(mat_a.sizes());
Strides tensor_ShapeB = make_strides(mat_b.sizes());
at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>(
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
@ -266,8 +264,6 @@ void bf16bf16_grouped_gemm_impl_sm90_sm100(
tensor_StrideA,
tensor_StrideB,
tensor_StrideOutput,
tensor_ShapeA,
tensor_ShapeB,
0,
0,
a_row_major,

View File

@ -38,20 +38,18 @@ __global__ void prepare_grouped_gemm_data(
Strides tensor_StrideA,
Strides tensor_StrideB,
Strides tensor_StrideOutput,
Strides tensor_ShapeA,
Strides tensor_ShapeB,
int64_t a_scale_stride,
int64_t b_scale_stride,
bool a_row_major = true,
bool b_row_major = false) {
int32_t tid = threadIdx.x;
int32_t delta = 0;
int32_t offset = 0;
if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1];
offset = offs[tid];
delta = offset - start;
CUDA_KERNEL_ASSERT(delta >=0 && "expected gemm dimension to be greater or equal 0\n");
delta = offs[tid] - start;
if (K < 0) {
CUDA_KERNEL_ASSERT(delta >=0 && "expected ofsets to be greater or equal 0\n");
}
// TMA transfers require global memory tensor addresses to be
// aligned to 16 bytes.
@ -86,7 +84,6 @@ __global__ void prepare_grouped_gemm_data(
int64_t lda, ldb, ldoutput;
if (M < 0) {
// A and output is 2d
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[0] && "expected offset to be less than tensor size\n");
M = delta;
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];
ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2];
@ -99,7 +96,6 @@ __global__ void prepare_grouped_gemm_data(
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
B_ptrs[tid] = B + tid * tensor_StrideB[0];
} else if (N < 0) {
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeB[1] && "expected offset to be less than tensor size\n");
N = delta;
lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2];
ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed
@ -112,7 +108,6 @@ __global__ void prepare_grouped_gemm_data(
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
}
} else if (K < 0) {
CUDA_KERNEL_ASSERT(offset <= tensor_ShapeA[1] && offset <= tensor_ShapeB[0] && "expected offset to be less than tensor size\n");
// A, B is 2d, output is 3d
K = delta;
lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1];

View File

@ -644,12 +644,7 @@ Tensor ctc_loss_backward_gpu_template(const Tensor& grad_out, const Tensor& log_
Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))
// As above, there may be better configurations to use.
constexpr int max_threads_ = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
int max_threads = max_threads_;
// Blackwell launch bounds
if (at::cuda::getCurrentDeviceProperties()->major >= 10) {
max_threads = 512;
}
constexpr int max_threads = std::is_same_v<scalar_t, float> ? 1024 : 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2*max_target_length+1) {
threads_target /= 2;

View File

@ -298,9 +298,6 @@ void f8f8bf16_grouped_gemm_impl_sm90(
Strides tensor_StrideA = make_strides(mat_a.strides());
Strides tensor_StrideB = make_strides(mat_b.strides());
Strides tensor_StrideOutput = make_strides(out.strides());
Strides tensor_ShapeA = make_strides(mat_a.sizes());
Strides tensor_ShapeB = make_strides(mat_b.sizes());
// scale stride will be used inside the kernel only if needed,
// so for 1d scales the "1" assigned here won't be used
int64_t a_scale_stride = scale_a.stride(0);
@ -328,8 +325,6 @@ void f8f8bf16_grouped_gemm_impl_sm90(
tensor_StrideA,
tensor_StrideB,
tensor_StrideOutput,
tensor_ShapeA,
tensor_ShapeB,
a_scale_stride,
b_scale_stride);

View File

@ -1,74 +0,0 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace at::native {
__global__ void weight_int8pack_mm_kernel(const float* x, const int8_t* w, const float* scale, float* out, int B, int K, int N) {
// one thread per output element: [B, N]
int b = blockIdx.y * blockDim.y + threadIdx.y;
int n = blockIdx.x * blockDim.x + threadIdx.x;
if (b >= B || n >= N) return;
float acc = 0.0f;
for (int k = 0; k < K; ++k) {
acc += x[b * K + k] * static_cast<float>(w[n * K + k]);
}
out[b * N + n] = acc * scale[n];
}
void launch_weight_int8pack_mm_cuda_kernel(const Tensor& x, const Tensor& w_int8, const Tensor& scale, Tensor& out) {
const int B = x.size(0);
const int K = x.size(1);
const int N = w_int8.size(0);
const dim3 block(16, 16);
const dim3 grid((N + block.x - 1) / block.x, (B + block.y - 1) / block.y);
auto stream = at::cuda::getCurrentCUDAStream();
weight_int8pack_mm_kernel<<<grid, block, 0, stream>>>(
x.data_ptr<float>(),
w_int8.data_ptr<int8_t>(),
scale.data_ptr<float>(),
out.data_ptr<float>(),
B, K, N);
}
// Main GPU entry point
at::Tensor _weight_int8pack_mm_cuda(const at::Tensor& x, const at::Tensor& w_int8, const at::Tensor& scale) {
// --- Check inputs ---
TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor");
TORCH_CHECK(w_int8.is_cuda(), "w must be a CUDA tensor");
TORCH_CHECK(scale.is_cuda(), "scale must be a CUDA tensor");
TORCH_CHECK(x.dim() == 2, "x must be 2D");
TORCH_CHECK(w_int8.dim() == 2, "w must be 2D");
TORCH_CHECK(scale.dim() == 1, "scale must be 1D");
TORCH_CHECK(x.size(1) == w_int8.size(1), "K dimension mismatch: x.size(1) != w.size(1)");
TORCH_CHECK(w_int8.size(0) == scale.size(0), "Output dim mismatch: w.size(0) != scale.size(0)");
// --- Determine shapes ---
auto B = x.size(0); // batch size
auto N = w_int8.size(0); // output dim
// Ensure inputs are in the correct types for the kernel
auto x_f32 = x.to(at::kFloat);
auto w_int8_contiguous = w_int8.contiguous();
auto scale_f32 = scale.to(at::kFloat);
// --- Allocate output ---
auto out = at::empty({B, N}, x.options().dtype(at::kFloat));
// --- Launch kernel ---
launch_weight_int8pack_mm_cuda_kernel(x_f32, w_int8_contiguous, scale_f32, out);
return out;
}
} // namespace at::native

View File

@ -28,22 +28,6 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support");
}
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
const Tensor& input,
const Tensor& weight,
const std::optional<Tensor>& bias,
const std::optional<Tensor>& running_mean,
const std::optional<Tensor>& running_var,
bool training,
double exponential_average_factor,
double epsilon,
Tensor& out,
Tensor& save_mean,
Tensor& save_var,
Tensor& reserve) {
AT_ERROR("cudnn_batch_norm_out: ATen not compiled with cuDNN support");
}
std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
const Tensor& input,
const Tensor& grad_output,
@ -136,12 +120,7 @@ size_t _get_cudnn_batch_norm_reserve_space_size(
return reserve_size;
}
// Param `reserve` is a placeholder, just passing an empty tensor.
// usage:
// auto reserve = torch::empty({0}, torch::device(torch::kCUDA));
// at::native::cudnn_batch_norm_out(..., epsilon, output, save_mean, save_var,
// reserve);
std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_t_opt,
@ -149,11 +128,7 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
const std::optional<Tensor>& running_var_t_opt,
bool training,
double exponential_average_factor,
double epsilon,
Tensor& output_t,
Tensor& save_mean,
Tensor& save_var,
Tensor& reserve) {
double epsilon) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> bias_t_maybe_owned =
at::borrow_from_optional_tensor(bias_t_opt);
@ -193,6 +168,9 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
cudnnBatchNormMode_t mode = getCudnnBatchNormMode(
training, input->suggest_memory_format(), input->dim());
auto output_t =
at::empty_like(*input, input->options(), input->suggest_memory_format());
TensorArg output{output_t, "output", 0};
auto handle = getCudnnHandle();
@ -204,8 +182,15 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
Constant one(dataType, 1);
Constant zero(dataType, 0);
Tensor save_mean, save_var;
Tensor reserve;
if (training) {
int64_t num_features = input_t.size(1);
save_mean = at::empty({num_features}, weight_t.options());
save_var = at::empty({num_features}, weight_t.options());
auto op = CUDNN_BATCHNORM_OPS_BN;
size_t workspace_size;
AT_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
@ -253,6 +238,9 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
reserve_size));
} else {
reserve = at::empty({0}, input->options().dtype(kByte));
// This keeps a consistent output with native_batch_norm
save_mean = at::empty({0}, weight_t.options());
save_var = at::empty({0}, weight_t.options());
AT_CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
handle,
mode,
@ -273,48 +261,10 @@ std::tuple<Tensor&, Tensor&, Tensor&, Tensor&> cudnn_batch_norm_out(
// save_mean and save_var can be undefined
// If this causes problems, we can initialize them to empty tensors
// of the correct type
return std::tuple<Tensor&, Tensor&, Tensor&, Tensor&>{
return std::tuple<Tensor, Tensor, Tensor, Tensor>{
output_t, save_mean, save_var, reserve};
}
std::tuple<Tensor, Tensor, Tensor, Tensor> cudnn_batch_norm(
const Tensor& input_t,
const Tensor& weight_t,
const std::optional<Tensor>& bias_t_opt,
const std::optional<Tensor>& running_mean_t_opt,
const std::optional<Tensor>& running_var_t_opt,
bool training,
double exponential_average_factor,
double epsilon) {
auto output_t = at::empty_like(
input_t, input_t.options(), input_t.suggest_memory_format());
Tensor save_mean, save_var, reserve;
if (training) {
int64_t num_features = input_t.size(1);
save_mean = at::empty({num_features}, weight_t.options());
save_var = at::empty({num_features}, weight_t.options());
} else {
// This keeps a consistent output with native_batch_norm
save_mean = at::empty({0}, weight_t.options());
save_var = at::empty({0}, weight_t.options());
}
return cudnn_batch_norm_out(
input_t,
weight_t,
bias_t_opt,
running_mean_t_opt,
running_var_t_opt,
training,
exponential_average_factor,
epsilon,
output_t,
save_mean,
save_var,
reserve);
}
// NB: CuDNN only implements the backward algorithm for batchnorm
// in training mode (evaluation mode batchnorm has a different algorithm),
// which is why this doesn't accept a 'training' parameter.
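A minimal caller-side sketch of the merged entry point above (assuming a CUDA build with cuDNN; shapes and hyper-parameters are illustrative): the four result tensors now come back by value rather than through out arguments.
#include <ATen/ATen.h>
void cudnn_bn_example() {
  auto opts = at::device(at::kCUDA).dtype(at::kFloat);
  auto input = at::randn({8, 32, 16, 16}, opts);
  auto weight = at::ones({32}, opts);
  auto bias = at::zeros({32}, opts);
  auto running_mean = at::zeros({32}, opts);
  auto running_var = at::ones({32}, opts);
  // training=true: save_mean/save_var hold num_features elements and reserve is
  // cuDNN-sized; training=false: all three come back as empty tensors.
  auto [output, save_mean, save_var, reserve] = at::cudnn_batch_norm(
      input, weight, bias, running_mean, running_var,
      /*training=*/true, /*exponential_average_factor=*/0.1, /*epsilon=*/1e-5);
  (void)output; (void)save_mean; (void)save_var; (void)reserve;
}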

View File

@ -1,6 +1,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Config.h>
#include <ATen/Context.h>
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/mkldnn/Matmul.h>
@ -427,74 +428,56 @@ static inline bool checksize(const Tensor& mat1, const Tensor& mat2){
}
}
bool use_mkldnn_bf16_matmul(
template <typename T>
bool use_mkldnn_typed_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
bool dtype_check = false;
if constexpr (std::is_same_v<T, c10::BFloat16>) {
#if defined(__aarch64__)
if (mkldnn_bf16_device_check_arm()) {
// onednn fastmath mode can leverage bf16 HW even for fp32 inputs, e.g. on
// Arm Neoverse V1, so don't restrict mkldnn_matmul to bf16 inputs only;
// allow it for float as well
return (
use_mkldnn_bf16_matmul() &&
(mat1.scalar_type() == mat2.scalar_type()) &&
(!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
} else
if (mkldnn_bf16_device_check_arm()) {
// onednn fastmath mode can leverage bf16 HW even for fp32 inputs, e.g. on
// Arm Neoverse V1, so don't restrict mkldnn_matmul to bf16 inputs only;
// allow it for float as well
dtype_check = use_mkldnn_bf16_matmul() &&
((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16));
}
#else
dtype_check = use_mkldnn_bf16_matmul() && (mat1.scalar_type() == kBFloat16);
#endif
{
return (
use_mkldnn_bf16_matmul() && mat1.scalar_type() == kBFloat16 &&
mat2.scalar_type() == kBFloat16 &&
(!result.defined() || result.scalar_type() == kBFloat16) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
} else if constexpr (std::is_same_v<T, c10::Half>) {
dtype_check = use_mkldnn_fp16_matmul() && (mat1.scalar_type() == kHalf);
} else if constexpr (std::is_same_v<T, float>) {
dtype_check = (use_mkldnn_bf32_matmul() || use_mkldnn_tf32_matmul()) &&
(mat1.scalar_type() == kFloat);
}
}
bool use_mkldnn_fp16_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_fp16_matmul() && mat1.scalar_type() == kHalf &&
mat2.scalar_type() == kHalf &&
(!result.defined() || result.scalar_type() == kHalf) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
bool use_mkldnn_bf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_bf32_matmul() && mat1.scalar_type() == kFloat &&
mat2.scalar_type() == kFloat &&
(!result.defined() || result.scalar_type() == kFloat) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
}
bool use_mkldnn_tf32_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_tf32_matmul() && mat1.scalar_type() == kFloat &&
mat2.scalar_type() == kFloat &&
(!result.defined() || result.scalar_type() == kFloat) &&
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2));
if (!dtype_check) {
return false;
}
bool size_check =
mat1.numel() != 0 && mat2.numel() != 0 && checksize(mat1, mat2);
dtype_check = (mat1.scalar_type() == mat2.scalar_type()) &&
(!result.defined() || result.scalar_type() == mat1.scalar_type());
return dtype_check && size_check;
}
bool use_mkldnn_matmul(
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
return (
use_mkldnn_bf16_matmul(mat1, mat2, result) ||
use_mkldnn_fp16_matmul(mat1, mat2, result) ||
use_mkldnn_bf32_matmul(mat1, mat2, result) ||
use_mkldnn_tf32_matmul(mat1, mat2, result));
auto mat1_type = mat1.scalar_type();
if (mat1_type != kBFloat16 && mat1_type != kHalf && mat1_type != kFloat) {
return false;
}
bool use_mkldnn = false;
AT_DISPATCH_FLOATING_TYPES_AND2(
kBFloat16, kHalf, mat1_type, "use_mkldnn_matmul", [&] {
use_mkldnn = use_mkldnn_typed_matmul<scalar_t>(mat1, mat2, result);
});
return use_mkldnn;
}
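The dispatch above funnels every supported dtype through a single templated predicate. A self-contained sketch of that pattern (names here are illustrative, not ATen APIs):
#include <cstdio>
enum class ScalarType { Float, Half, BFloat16, Other };
struct HalfTag {};
struct BFloat16Tag {};
// Per-dtype policy lives in the template; this sketch only checks non-emptiness.
template <typename T>
bool typed_check(long numel) {
  return numel != 0;
}
bool run_typed_check(ScalarType st, long numel) {
  switch (st) {
    case ScalarType::Float:
      return typed_check<float>(numel);
    case ScalarType::Half:
      return typed_check<HalfTag>(numel);
    case ScalarType::BFloat16:
      return typed_check<BFloat16Tag>(numel);
    default:
      return false;  // unsupported dtype: mirror the early return above
  }
}
int main() {
  std::printf("%d\n", run_typed_check(ScalarType::BFloat16, 64) ? 1 : 0);  // 1
}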
static void _mkldnn_matmul_i8i8i32_with_primitive(

View File

@ -469,94 +469,4 @@ Tensor _weight_int4pack_mm_xpu(
return C;
}
Tensor& _int_mm_out_xpu(
const Tensor& self,
const Tensor& mat2,
Tensor& result) {
TORCH_CHECK(
self.dim() == 2,
"Expected self to be of dimension 2 but got ",
self.dim());
TORCH_CHECK(
mat2.dim() == 2,
"Expected mat2 to be of dimension 2 but got ",
mat2.dim());
TORCH_CHECK(
self.size(1) == mat2.size(0),
"self.size(1) needs to match mat2.size(0) but got ",
self.size(1),
" and ",
mat2.size(0));
TORCH_CHECK(
self.dtype() == at::kChar,
"Expected self dtype to be of type int8 but got ",
self.dtype());
TORCH_CHECK(
mat2.dtype() == at::kChar,
"Expected mat2 dtype to be of type int8 but got ",
mat2.dtype());
TORCH_CHECK(
result.dtype() == at::kInt,
"Expected result dtype to be of type kInt but got ",
result.dtype());
TORCH_CHECK(
result.size(0) == self.size(0),
"Expected result.size(0) to be ",
self.size(0),
" but got ",
result.size(0));
TORCH_CHECK(
result.size(1) == mat2.size(1),
"Expected result.size(1) to be ",
mat2.size(1),
" but got ",
result.size(1));
TORCH_CHECK(
result.dim() == 2,
"Expected result to be of dimension 2 but got ",
result.dim());
TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
if (result.numel() == 0 || self.size(1) == 0) {
return result.zero_();
}
Tensor bias = at::Tensor();
Tensor mat2_scales = at::ones({1}, mat2.options().dtype(at::kFloat));
Tensor mat2_zero_points = at::Tensor();
auto post_op_args = torch::List<std::optional<at::Scalar>>();
at::native::onednn::quantized_matmul(
self.contiguous(),
1.0,
0,
mat2.contiguous(),
mat2_scales,
mat2_zero_points,
bias,
result,
1.0,
0,
result.scalar_type(),
/*other*/ std::nullopt,
/*other scale*/ 1.0,
/*other zp*/ 0,
/*binary post op*/ "none",
/*binary alpha*/ 1.0,
/*post_op_name*/ "none",
post_op_args,
/*post_op_algorithm*/ "none",
/*m2_trans*/ true);
return result;
}
Tensor _int_mm_xpu(const Tensor& self, const Tensor& mat2) {
Tensor result =
at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
return _int_mm_out_xpu(self, mat2, result);
}
} // namespace at::native
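For reference, a minimal sketch of the operator these XPU kernels back: at::_int_mm multiplies two 2-D int8 tensors into an int32 result (shapes are illustrative; backend-specific size constraints are omitted).
#include <ATen/ATen.h>
at::Tensor int_mm_example() {
  auto a = at::randint(-128, 128, {64, 32}, at::kChar);  // int8
  auto b = at::randint(-128, 128, {32, 48}, at::kChar);  // int8
  return at::_int_mm(a, b);                              // int32, shape [64, 48]
}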

View File

@ -953,7 +953,8 @@ class BundledShaderLibary : public MetalShaderLibrary {
if (C10_UNLIKELY(!library)) {
auto device = MPSDevice::getInstance()->device();
NSError* error = nil;
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
}
return library;

View File

@ -33,15 +33,21 @@ struct shrink_backward_functor {
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
#endif
struct hardsigmoid_functor {
template <typename T>
@ -61,11 +67,15 @@ struct hardsigmoid_backward_functor {
REGISTER_UNARY_OP(hardsigmoid, float, float);
REGISTER_UNARY_OP(hardsigmoid, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
#endif
struct hardswish_functor {
template <typename T>
@ -93,11 +103,15 @@ struct hardswish_backward_functor {
REGISTER_UNARY_OP(hardswish, float, float);
REGISTER_UNARY_OP(hardswish, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardswish_backward, float, float);
REGISTER_BINARY_OP(hardswish_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
#endif
struct leaky_relu_functor {
template <typename T>
@ -121,8 +135,12 @@ struct leaky_relu_backward_functor {
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
#endif
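These registrations are compiled only when the shading-language version provides bfloat. A standalone C++ sketch of the same gating idiom (LANG_VERSION stands in for __METAL_VERSION__; all names are illustrative), mirroring the _METAL_310_PLUS helper defined in a later hunk:
#include <cstdio>
#define LANG_VERSION 310  // stands in for __METAL_VERSION__
#if LANG_VERSION >= 310
#define IF_BF16(x) x
#else
#define IF_BF16(x)
#endif
// Token-pasted "registration": float/half are unconditional, bfloat is gated.
#define REGISTER(NAME, SUFFIX) \
  void NAME##_##SUFFIX() { std::printf("registered " #NAME "_" #SUFFIX "\n"); }
REGISTER(hardsigmoid, float32)
REGISTER(hardsigmoid, float16)
IF_BF16(REGISTER(hardsigmoid, bfloat16))
int main() {
  hardsigmoid_float32();
  hardsigmoid_float16();
  IF_BF16(hardsigmoid_bfloat16();)
}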

View File

@ -113,12 +113,18 @@ kernel void ampUpdateScale(
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
#endif
INSTANTIATE_AMP_UPDATE_SCALE(float);
INSTANTIATE_AMP_UPDATE_SCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
#endif
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
#endif

View File

@ -590,7 +590,9 @@ kernel void attention(
INSTANTIATE_SDPA_VECTOR_HEADS(float);
INSTANTIATE_SDPA_VECTOR_HEADS(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
#endif
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
@ -619,4 +621,6 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
INSTANTIATE_ATTN_SHAPES_HELPER(float);
INSTANTIATE_ATTN_SHAPES_HELPER(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
#endif

View File

@ -209,9 +209,38 @@ struct hermite_polynomial_he_functor {
};
struct nextafter_functor {
#if __METAL_VERSION__ < 310
template <typename U>
struct bit_type {};
template <>
struct bit_type<float> {
using type = int;
};
template <>
struct bit_type<half> {
using type = short;
};
#endif
template <typename T>
inline T operator()(const T a, const T b) {
#if __METAL_VERSION__ >= 310
return static_cast<T>(::metal::nextafter(a, b));
#else
using U = typename bit_type<T>::type;
if (a == b) {
return a;
}
if (::metal::isunordered(a, b)) {
return NAN;
}
if (a == 0) {
constexpr auto eps = as_type<T>(static_cast<U>(1));
return b > 0 ? eps : -eps;
}
auto bits = as_type<U>(a);
(a > 0) ^ (a > b) ? bits++ : bits--;
return as_type<T>(bits);
#endif
}
};
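A host-side C++ sketch of the same bit-stepping trick the fallback above uses for nextafter: reinterpret the float as a same-width integer, then move the bit pattern one step toward the target (expected output is noted in the comments).
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
float nextafter_bits(float a, float b) {
  if (a == b) return a;
  if (std::isnan(a) || std::isnan(b)) return NAN;
  if (a == 0.0f) {
    float eps;
    std::uint32_t one = 1;                 // smallest positive subnormal
    std::memcpy(&eps, &one, sizeof(eps));
    return b > 0 ? eps : -eps;
  }
  std::uint32_t bits;
  std::memcpy(&bits, &a, sizeof(bits));
  // Step the magnitude up or down depending on which side of a the target b is.
  ((a > 0) ^ (a > b)) ? ++bits : --bits;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}
int main() {
  std::printf("%.9g\n", nextafter_bits(1.0f, 2.0f));  // 1.00000012
}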
@ -315,6 +344,13 @@ struct fmod_functor {
}
};
// Some helper defines
#if __METAL_VERSION__ >= 310
#define _METAL_310_PLUS(x) x
#else
#define _METAL_310_PLUS(x)
#endif
#define REGISTER_INTEGER_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, long, long); \
REGISTER_BINARY_OP(NAME, int, int); \
@ -334,12 +370,12 @@ struct fmod_functor {
#define REGISTER_FLOAT_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, float, float); \
REGISTER_BINARY_OP(NAME, half, half); \
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
@ -411,9 +447,11 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
#endif
// Complex binary functions
REGISTER_BINARY_OP(polar, float, float2);

View File

@ -180,8 +180,10 @@ REGISTER_SEARCHSORTED_OP(float, int);
REGISTER_SEARCHSORTED_OP(float, long);
REGISTER_SEARCHSORTED_OP(half, int);
REGISTER_SEARCHSORTED_OP(half, long);
#if __METAL_VERSION__ >= 310
REGISTER_SEARCHSORTED_OP(bfloat, int);
REGISTER_SEARCHSORTED_OP(bfloat, long);
#endif
REGISTER_SEARCHSORTED_OP(char, int);
REGISTER_SEARCHSORTED_OP(char, long);
REGISTER_SEARCHSORTED_OP(uchar, int);

View File

@ -96,4 +96,6 @@ kernel void col2im_kernel(
INSTANTIATE_COL2IM(bool);
INSTANTIATE_COL2IM(float);
INSTANTIATE_COL2IM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_COL2IM(bfloat);
#endif

View File

@ -20,7 +20,9 @@ REGISTER_CROSS_FUNC(short);
REGISTER_CROSS_FUNC(char);
REGISTER_CROSS_FUNC(uchar);
REGISTER_CROSS_FUNC(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_FUNC(bfloat);
#endif
template <typename T, typename U>
kernel void cross(
@ -66,4 +68,6 @@ REGISTER_CROSS_OP(short);
REGISTER_CROSS_OP(char);
REGISTER_CROSS_OP(uchar);
REGISTER_CROSS_OP(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_OP(bfloat);
#endif

View File

@ -1,9 +1,11 @@
#include <metal_stdlib>
using metal::max;
#if __METAL_VERSION__ >= 310
bfloat max(bfloat a, bfloat b) {
return a > b ? a : b;
}
#endif
#define kmaxThreadGroups 32
#define kmaxTensors 32
@ -304,9 +306,11 @@ REGISTER_ADAM_OPS_QUART(float, float);
REGISTER_ADAM_OPS_QUART(float, half);
REGISTER_ADAM_OPS_QUART(half, float);
REGISTER_ADAM_OPS_QUART(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_ADAM_OPS_QUART(float, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, float);
#endif
template <typename T>
inline void sgd_momentum_math(
@ -456,5 +460,7 @@ REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_FUSED_SGD_OP(bfloat);
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
#endif

View File

@ -106,7 +106,9 @@ kernel void polygamma(
constant int64_t& order [[buffer(2)]], \
uint id [[thread_position_in_grid]]);
#if __METAL_VERSION__ >= 310
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
#endif
INSTANTIATE_GAMMA_KERNELS(half, half);
INSTANTIATE_GAMMA_KERNELS(float, float);
INSTANTIATE_GAMMA_KERNELS(bool, float);

View File

@ -76,4 +76,6 @@ INSTANTIATE_IM2COL(float);
INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
#if __METAL_VERSION__ >= 310
INSTANTIATE_IM2COL(bfloat);
#endif

View File

@ -240,7 +240,9 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
#endif
template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(
@ -475,8 +477,10 @@ INSTANTIATE_INDEX_COPY(char, long);
INSTANTIATE_INDEX_COPY(uchar, int);
INSTANTIATE_INDEX_COPY(uchar, long);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INDEX_COPY(bfloat, int);
INSTANTIATE_INDEX_COPY(bfloat, long);
#endif
INSTANTIATE_INDEX_COPY(float2, int);
INSTANTIATE_INDEX_COPY(float2, long);
INSTANTIATE_INDEX_COPY(half2, int);

View File

@ -288,6 +288,7 @@ kernel void layer_norm_looped(
#define instantiate_layer_norm(DTYPE) \
instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)
instantiate_layer_norm(float);
instantiate_layer_norm(half);
instantiate_layer_norm(bfloat);
instantiate_layer_norm(float) instantiate_layer_norm(half)
#if __METAL_VERSION__ >= 310
instantiate_layer_norm(bfloat)
#endif

View File

@ -635,7 +635,9 @@ kernel void applyPivots(
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_NAIVE_MM(bfloat);
#endif
// Integral MM
INSTANTIATE_NAIVE_MM(short);

View File

@ -48,14 +48,3 @@ struct PoolingBackwardParams {
::c10::metal::array<idx_type_t, N> grad_output_strides;
::c10::metal::array<idx_type_t, N> indices_strides;
};
template <unsigned N = 5, typename idx_type_t = int32_t>
struct MaxUnpoolingParams {
int32_t dims;
int32_t pooling_dims;
::c10::metal::array<idx_type_t, N> input_sizes;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> indices_strides;
};

View File

@ -168,16 +168,6 @@ PoolOffsets find_pool_offsets(
leading_dims,
return_indices,
tid);
case 3:
return find_pool_offsets_dim_specific<3>(
output_sizes,
output_strides,
indices_strides,
input_strides,
pooling_dim_indices,
leading_dims,
return_indices,
tid);
}
return PoolOffsets();
}
@ -302,68 +292,6 @@ kernel void max_pool_backward(
pooling_dims);
}
template <typename T>
void max_unpool_impl(
device T* output,
T input_element,
int32_t input_index,
constant int32_t* output_sizes,
constant int32_t* output_strides,
int32_t pooling_dims) {
int32_t size_prod = 1;
int32_t pool_offset = 0;
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
auto next_size_prod = output_sizes[dim] * size_prod;
pool_offset +=
output_strides[dim] * ((input_index % next_size_prod) / size_prod);
size_prod *= output_sizes[dim];
}
output[pool_offset] = input_element;
}
// Kernel computes one element of the grad input per kernel call.
template <typename T>
kernel void max_unpool(
device T* output [[buffer(0)]],
constant T* input [[buffer(1)]],
constant int64_t* indices [[buffer(2)]],
constant MaxUnpoolingParams<5>& params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto input_sizes = params.input_sizes.data();
auto input_strides = params.input_strides.data();
auto output_sizes = params.output_sizes.data();
auto output_strides = params.output_strides.data();
auto indices_strides = params.indices_strides.data();
auto leading_dims = dims - pooling_dims;
// NOTE: Since we're doing unpooling, the variable names "input" and "output"
// are reversed compared to the pooling operations. So in `find_pool_offsets`,
// we need to map "input" -> "output" and "output" -> "input".
PoolOffsets offsets = find_pool_offsets(
/*output_sizes=*/input_sizes,
/*output_strides=*/input_strides,
indices_strides,
/*input_strides=*/output_strides,
/*pooling_dim_indices=*/nullptr,
dims,
leading_dims,
/*return_indices=*/true,
tid);
max_unpool_impl<T>(
output + offsets.input_leading,
input[offsets.output],
indices[offsets.indices],
output_sizes + leading_dims,
output_strides + leading_dims,
pooling_dims);
}
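A CPU sketch of what this kernel computes per element: each pooled value is scattered to the flattened position recorded by the forward pass's indices, and every other output location stays zero (1-D, illustrative names only).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>
void max_unpool_1d(const std::vector<float>& pooled,
                   const std::vector<long>& indices,  // argmax offsets from the forward pass
                   std::vector<float>& output) {
  std::fill(output.begin(), output.end(), 0.f);
  for (std::size_t i = 0; i < pooled.size(); ++i) {
    output[static_cast<std::size_t>(indices[i])] = pooled[i];
  }
}
int main() {
  std::vector<float> pooled{5.f, 7.f};
  std::vector<long> idx{1, 2};
  std::vector<float> out(4);
  max_unpool_1d(pooled, idx, out);
  for (float v : out) std::printf("%g ", v);  // 0 5 7 0
}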
template <typename T>
struct AvgPoolIterBounds {
T start;
@ -430,6 +358,7 @@ void avg_pool_3d_input_iter(
auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto size12 = input_sizes[1] * input_sizes[2];
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = input_strides[0] * i0;
@ -447,64 +376,6 @@ void avg_pool_3d_input_iter(
*output = value_sum / static_cast<T>(divisor);
}
template <typename T>
void avg_pool_backward_3d_input_iter(
device AtomicType_t<T>* grad_input,
constant T* grad_output,
constant int32_t* grad_input_sizes,
constant int32_t* grad_input_strides,
int32_t grad_input_leading_offset,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
bool count_include_pad,
bool has_divisor_override,
int32_t divisor_override) {
auto bounds0 = get_avg_pool_input_iter_bounds<0>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds1 = get_avg_pool_input_iter_bounds<1>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto bounds2 = get_avg_pool_input_iter_bounds<2>(
grad_input_sizes,
pooling_dim_indices,
kernel_size,
stride,
padding,
count_include_pad);
auto divisor = has_divisor_override
? divisor_override
: (bounds0.count) * (bounds1.count) * (bounds2.count);
auto grad_val = *grad_output / static_cast<T>(divisor);
for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
auto offset0 = grad_input_strides[0] * i0;
for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
auto offset1 = grad_input_strides[1] * i1;
for (auto i2 = bounds2.start; i2 < bounds2.end; i2++) {
auto offset2 = grad_input_strides[2] * i2;
auto pool_offset = offset0 + offset1 + offset2;
AtomicType<T>::atomic_add(
grad_input, grad_input_leading_offset + pool_offset, grad_val);
}
}
}
}
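A 1-D CPU sketch of the rule this routine applies: each grad_output element is divided by its window size and accumulated into every grad_input position the window covered (the GPU version does the accumulation with atomic_add; padding and divisor_override only change the divisor and are omitted here).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>
void avg_pool1d_backward(const std::vector<float>& grad_output,
                         std::vector<float>& grad_input,
                         int kernel, int stride) {
  std::fill(grad_input.begin(), grad_input.end(), 0.f);
  for (std::size_t o = 0; o < grad_output.size(); ++o) {
    const float g = grad_output[o] / static_cast<float>(kernel);  // divisor = window size
    for (int k = 0; k < kernel; ++k) {
      grad_input[o * stride + k] += g;
    }
  }
}
int main() {
  std::vector<float> go{1.f, 2.f};
  std::vector<float> gi(4);
  avg_pool1d_backward(go, gi, /*kernel=*/2, /*stride=*/2);
  for (float v : gi) std::printf("%g ", v);  // 0.5 0.5 1 1
}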
// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void avg_pool(
@ -557,97 +428,31 @@ kernel void avg_pool(
params.divisor_override);
}
template <typename T>
kernel void avg_pool_backward(
device AtomicType_t<T>* grad_input [[buffer(0)]],
constant T* grad_output [[buffer(1)]],
constant AvgPoolingParams<5>& params [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto grad_input_sizes = params.input_sizes.data();
auto grad_input_strides = params.input_strides.data();
auto grad_output_sizes = params.output_sizes.data();
auto grad_output_strides = params.output_strides.data();
auto kernel_size = params.kernel_size.data();
auto stride = params.stride.data();
auto padding = params.padding.data();
auto leading_dims = dims - pooling_dims;
// This buffer keeps track of the pooling dimension indices of this thread's
// element of the output. We need to fill it with the proper values below.
int32_t pooling_dim_indices[3];
PoolOffsets offsets = find_pool_offsets(
grad_output_sizes,
grad_output_strides,
/*indices_strides=*/nullptr,
grad_input_strides,
pooling_dim_indices,
dims,
leading_dims,
/*return_indices=*/false,
tid);
grad_output += offsets.output;
grad_input_sizes += leading_dims;
grad_input_strides += leading_dims;
avg_pool_backward_3d_input_iter<T>(
grad_input,
grad_output,
grad_input_sizes,
grad_input_strides,
offsets.input_leading,
pooling_dim_indices,
kernel_size,
stride,
padding,
params.count_include_pad,
params.has_divisor_override,
params.divisor_override);
}
#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("max_unpool_" #DTYPE)]] kernel void max_unpool<DTYPE>( \
device DTYPE * output [[buffer(0)]], \
constant DTYPE * input [[buffer(1)]], \
constant int64_t* indices [[buffer(2)]], \
constant MaxUnpoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
#define REGISTER_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_" #DTYPE)]] kernel void avg_pool<DTYPE>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_POOL_BACKWARD_OP(DTYPE) \
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
template [[host_name("max_pool_backward_" #DTYPE)]] \
kernel void max_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output_ [[buffer(1)]], \
constant int64_t* grad_indices_ [[buffer(2)]], \
constant PoolingBackwardParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]); \
\
template [[host_name("avg_pool_backward_" #DTYPE)]] \
kernel void avg_pool_backward<DTYPE>( \
device AtomicType_t<DTYPE> * grad_input [[buffer(0)]], \
constant DTYPE * grad_output [[buffer(1)]], \
constant AvgPoolingParams<5> & params [[buffer(2)]], \
uint tid [[thread_position_in_grid]]);
REGISTER_POOL_OP(float);
REGISTER_POOL_OP(half);
REGISTER_POOL_OP(bfloat);
REGISTER_POOL_OP(int);
REGISTER_POOL_OP(long);
REGISTER_POOL_OP(short);
@ -655,6 +460,10 @@ REGISTER_POOL_OP(char);
REGISTER_POOL_OP(uchar);
REGISTER_POOL_OP(bool);
REGISTER_POOL_BACKWARD_OP(float);
REGISTER_POOL_BACKWARD_OP(half);
REGISTER_POOL_BACKWARD_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_POOL_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
#endif

View File

@ -197,10 +197,12 @@ INSTANTIATE_INT4MV(float, 128);
INSTANTIATE_INT4MV(half, 128);
INSTANTIATE_INT4MV(float, 256);
INSTANTIATE_INT4MV(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT4MV(bfloat, 32);
INSTANTIATE_INT4MV(bfloat, 64);
INSTANTIATE_INT4MV(bfloat, 128);
INSTANTIATE_INT4MV(bfloat, 256);
#endif
// ------------------------------ int8 MM For M >= 12 ------------------------------------
/**
@ -232,10 +234,12 @@ template <> struct BlockType<half> {
using simdgroup_type8x8 = simdgroup_half8x8;
using type4 = half4;
};
#if __METAL_VERSION__ >= 310
template <> struct BlockType<bfloat> {
using simdgroup_type8x8 = simdgroup_bfloat8x8;
using type4 = bfloat4;
};
#endif
template<typename T>
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
@ -486,7 +490,9 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
INSTANTIATE_MM(float, char, get_scale_zero_q8);
INSTANTIATE_MM(half, char, get_scale_zero_q8);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
#endif
// ------------------------------ int8 MM For M < 12 ------------------------------------
/* Matrix-vector multiplication; also used for matrix multiplication when M is small.
@ -640,4 +646,6 @@ kernel void kernel_mul_mv<DTYPE>(
INSTANTIATE_MV(float);
INSTANTIATE_MV(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MV(bfloat);
#endif
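These kernels build on per-channel int8 dequantisation. A generic sketch of the arithmetic (the scale/zero-point layout is illustrative, not the exact packing of scalesAndZeros above):
#include <cstdint>
#include <cstdio>
float dequantize_q8(int8_t q, float scale, int32_t zero_point) {
  // real_value = scale * (quantised - zero_point)
  return scale * static_cast<float>(static_cast<int32_t>(q) - zero_point);
}
int main() {
  std::printf("%g\n", dequantize_q8(/*q=*/12, /*scale=*/0.05f, /*zero_point=*/2));  // 0.5
}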

View File

@ -192,4 +192,6 @@ template <typename T>
instantiate_rms(float)
instantiate_rms(half)
#if __METAL_VERSION__ >= 310
instantiate_rms(bfloat)
#endif // clang-format on

View File

@ -23,4 +23,6 @@ kernel void renorm(
REGISTER_RENORM_OP(float);
REGISTER_RENORM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_RENORM_OP(bfloat);
#endif

View File

@ -25,6 +25,379 @@ struct LogAddExp {
};
};
#if __METAL_VERSION__ < 310
template <typename T, typename acc_t = accum_t<T>>
struct CumMinOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::min(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::max());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct CumMaxOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::max(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::lowest());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct LogCumSumExpOp {
static acc_t apply(acc_t x, acc_t y) {
return LogAddExp{}(x, y);
}
static acc_t identity() {
return -metal::numeric_limits<acc_t>::infinity();
}
};
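LogCumSumExpOp folds elements with LogAddExp and uses -infinity as its identity. A host-side sketch of the numerically stable combine, independent of the Metal helpers above:
#include <algorithm>
#include <cmath>
#include <cstdio>
float log_add_exp(float x, float y) {
  if (std::isinf(x) && x < 0) return y;  // -inf is the identity element
  if (std::isinf(y) && y < 0) return x;
  float m = std::max(x, y);
  return m + std::log1p(std::exp(-std::fabs(x - y)));  // avoids overflow of exp
}
int main() {
  // Running logcumsumexp over {0, 0, 0} gives log(1), log(2), log(3).
  float acc = -INFINITY;
  for (int i = 0; i < 3; ++i) {
    acc = log_add_exp(acc, 0.0f);
    std::printf("%g ", acc);  // 0 0.693147 1.09861
  }
}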
// Inclusive scan along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_rows [[buffer(2)]],
constant uint& row_size [[buffer(3)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[offset + col] = static_cast<T>(accumulator);
}
}
// Inclusive scan along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_orows [[buffer(2)]],
constant uint& num_irows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[idx] = static_cast<T>(accumulator);
}
}
// Inclusive scan with indices along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_rows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[offset + col] = static_cast<T>(accumulator);
indices[offset + col] = best_idx;
}
}
// Inclusive scan with indices along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_orows [[buffer(3)]],
constant uint& num_irows [[buffer(4)]],
constant uint& row_size [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[idx] = static_cast<T>(accumulator);
indices[idx] = best_idx;
}
}
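A host-side sketch of the scan-with-indices recurrence these kernels apply along each row: carry an accumulator initialised to the op's identity and remember the index of the element that last won the comparison (CumMaxOp shown; ties move the index forward, matching the kernels above).
#include <cstddef>
#include <cstdio>
#include <limits>
#include <vector>
struct CumMaxOp {
  static float apply(float a, float b) { return a > b ? a : b; }
  static float identity() { return -std::numeric_limits<float>::infinity(); }
};
template <typename Op>
void scan_with_indices(const std::vector<float>& in,
                       std::vector<float>& values,
                       std::vector<long>& indices) {
  float acc = Op::identity();
  long best_idx = 0;
  for (std::size_t i = 0; i < in.size(); ++i) {
    if (i == 0 || Op::apply(in[i], acc) == in[i]) {
      acc = in[i];
      best_idx = static_cast<long>(i);
    }
    values[i] = acc;
    indices[i] = best_idx;
  }
}
int main() {
  std::vector<float> in{1.f, 3.f, 2.f, 5.f, 4.f};
  std::vector<float> values(in.size());
  std::vector<long> indices(in.size());
  scan_with_indices<CumMaxOp>(in, values, indices);
  for (std::size_t i = 0; i < in.size(); ++i)
    std::printf("%g@%ld ", values[i], indices[i]);  // 1@0 3@1 3@1 5@3 5@3
}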
// Shared utility functions for strided kernels
inline long calculate_non_scan_elements(
constant long* sizes,
uint ndim,
uint scan_dim) {
long total = 1;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
total *= sizes[i];
}
}
return total;
}
inline void thread_index_to_coordinates(
uint index,
int pos[c10::metal::max_ndim],
constant long* sizes,
uint ndim,
uint scan_dim) {
long remaining_index = index;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
pos[i] = remaining_index % sizes[i];
remaining_index /= sizes[i];
} else {
pos[i] = 0;
}
}
}
inline long calculate_base_offset(
int pos[c10::metal::max_ndim],
constant long* strides,
uint ndim,
uint scan_dim) {
long offset = 0;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
offset += pos[i] * strides[i];
}
}
return offset;
}
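A worked example of the index math above: map a linear thread id to coordinates of all non-scan dimensions, then fold those coordinates with the strides into a base offset (sizes and strides are illustrative).
#include <cstdio>
int main() {
  const long sizes[3] = {2, 4, 3};     // scan over dim 1
  const long strides[3] = {12, 3, 1};  // contiguous strides
  const unsigned scan_dim = 1;
  const unsigned thread_index = 5;     // one of 2 * 3 = 6 non-scan elements
  int pos[3];
  long rem = thread_index;
  for (unsigned i = 0; i < 3; ++i) {
    if (i != scan_dim) { pos[i] = rem % sizes[i]; rem /= sizes[i]; }
    else { pos[i] = 0; }
  }
  long offset = 0;
  for (unsigned i = 0; i < 3; ++i) {
    if (i != scan_dim) offset += pos[i] * strides[i];
  }
  std::printf("pos = (%d, %d, %d), base offset = %ld\n",
              pos[0], pos[1], pos[2], offset);  // pos = (1, 0, 2), base offset = 14
}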
// Generic strided scan kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_strided(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant long* sizes [[buffer(2)]],
constant long* input_strides [[buffer(3)]],
constant long* output_strides [[buffer(4)]],
constant uint& ndim [[buffer(5)]],
constant uint& scan_dim [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long output_base_offset =
calculate_base_offset(pos, output_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long output_scan_stride = output_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long output_offset =
output_base_offset + scan_idx * output_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[output_offset] = static_cast<T>(accumulator);
}
}
// Generic strided scan with indices kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_strided(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant long* sizes [[buffer(3)]],
constant long* input_strides [[buffer(4)]],
constant long* values_strides [[buffer(5)]],
constant long* indices_strides [[buffer(6)]],
constant uint& ndim [[buffer(7)]],
constant uint& scan_dim [[buffer(8)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long values_base_offset =
calculate_base_offset(pos, values_strides, ndim, scan_dim);
const long indices_base_offset =
calculate_base_offset(pos, indices_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long values_scan_stride = values_strides[scan_dim];
const long indices_scan_stride = indices_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long values_offset =
values_base_offset + scan_idx * values_scan_stride;
const long indices_offset =
indices_base_offset + scan_idx * indices_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = scan_idx;
}
values[values_offset] = static_cast<T>(accumulator);
indices[indices_offset] = best_idx;
}
}
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_rows [[buffer(2)]], \
constant uint & row_size [[buffer(3)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_orows [[buffer(2)]], \
constant uint & num_irows [[buffer(3)]], \
constant uint & row_size [[buffer(4)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant long* sizes [[buffer(2)]], \
constant long* input_strides [[buffer(3)]], \
constant long* output_strides [[buffer(4)]], \
constant uint& ndim [[buffer(5)]], \
constant uint& scan_dim [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_rows [[buffer(3)]], \
constant uint& row_size [[buffer(4)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_orows [[buffer(3)]], \
constant uint& num_irows [[buffer(4)]], \
constant uint& row_size [[buffer(5)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant long* sizes [[buffer(3)]], \
constant long* input_strides [[buffer(4)]], \
constant long* values_strides [[buffer(5)]], \
constant long* indices_strides [[buffer(6)]], \
constant uint& ndim [[buffer(7)]], \
constant uint& scan_dim [[buffer(8)]], \
uint thread_index [[thread_position_in_grid]]);
// Simple scan operations
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);
// Scan operations with indices
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
#else // __METAL_VERSION__ >= 310
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
// The remainder of this file contains cummin and cummax implementations adapted
@ -786,3 +1159,5 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
#endif

View File

@ -89,4 +89,6 @@ REGISTER_SPECIAL(short, float);
REGISTER_SPECIAL(int, float);
REGISTER_SPECIAL(long, float);
REGISTER_SPECIAL(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_SPECIAL(bfloat, bfloat);
#endif

View File

@ -100,7 +100,9 @@ kernel void triul(
INSTANTIATE_TRIUL_KERNELS(float, int);
INSTANTIATE_TRIUL_KERNELS(half, int);
#if __METAL_VERSION__ >= 310
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
#endif
INSTANTIATE_TRIUL_KERNELS(float2, int);
INSTANTIATE_TRIUL_KERNELS(half2, int);

View File

@ -556,9 +556,11 @@ REGISTER_UNARY_OP(abs, half, half);
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
REGISTER_UNARY_OP(neg, bfloat, bfloat);
REGISTER_UNARY_OP(abs, bfloat, bfloat);
#endif
INSTANTIATE_UNARY_KERNELS2(half, half);
INSTANTIATE_UNARY_KERNELS2(float, float);
INSTANTIATE_UNARY_KERNELS2(float, bool);
@ -598,4 +600,6 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
#endif

View File

@ -70,4 +70,6 @@ kernel void unfold_backward(
INSTANTIATE_UNFOLD_BACKWARD(float);
INSTANTIATE_UNFOLD_BACKWARD(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
#endif

View File

@ -852,4 +852,6 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
INSTANTIATE_UPSAMPLE_3D(uchar);
INSTANTIATE_UPSAMPLE_ALL(float);
INSTANTIATE_UPSAMPLE_ALL(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UPSAMPLE_ALL(bfloat);
#endif

View File

@ -14,7 +14,6 @@
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/avg_pool2d_backward_native.h>
#include <ATen/ops/avg_pool2d_native.h>
#include <ATen/ops/avg_pool3d_backward_native.h>
#include <ATen/ops/avg_pool3d_native.h>
#include <ATen/ops/max_pool2d_backward_native.h>
#include <ATen/ops/max_pool2d_native.h>
@ -22,8 +21,6 @@
#include <ATen/ops/max_pool2d_with_indices_native.h>
#include <ATen/ops/max_pool3d_with_indices_backward_native.h>
#include <ATen/ops/max_pool3d_with_indices_native.h>
#include <ATen/ops/max_unpool2d_native.h>
#include <ATen/ops/max_unpool3d_native.h>
#endif
namespace at::native {
@ -495,60 +492,6 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
});
}
static void max_unpool_out_mps_template(const Tensor& input,
const Tensor& indices,
IntArrayRef output_size_,
IntArrayRef stride,
IntArrayRef padding,
Tensor& output,
const int32_t pooling_dims,
const std::string& op_name) {
auto dims = input.dim();
auto leading_dims = input.dim() - pooling_dims;
const auto memory_format = input.suggest_memory_format();
std::vector<int64_t> output_size(dims);
for (int dim : c10::irange(leading_dims)) {
output_size[dim] = input.sizes()[dim];
}
for (int dim : c10::irange(pooling_dims)) {
output_size[leading_dims + dim] = output_size_[dim];
}
output.resize_(output_size, memory_format);
output.fill_(0);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = input.numel();
MaxUnpoolingParams<5> params;
params.dims = dims;
params.pooling_dims = pooling_dims;
for (const auto dim : c10::irange(dims)) {
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
}
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("max_unpool_" + scalarToMetalTypeString(input));
getMPSProfiler().beginProfileKernel(PSO, op_name, {input});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, output, input, indices, params);
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}
static void avg_pool2d_template(const Tensor& input,
const Tensor& output,
const std::optional<Tensor>& grad_output_opt,
@ -726,64 +669,6 @@ static void avg_pool_out_mps_template(const Tensor& output,
});
}
static void avg_pool_backward_out_mps_template(const Tensor& grad_input,
const Tensor& input,
const Tensor& grad_output,
IntArrayRef _kernel_size,
IntArrayRef _stride,
IntArrayRef _padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const int32_t pooling_dims,
const std::string& op_name) {
auto [dims, _, kernel_size, stride, padding, __] =
process_pool_sizes(input, _kernel_size, _stride, _padding, std::nullopt, ceil_mode, pooling_dims, op_name);
const auto memory_format = input.suggest_memory_format();
grad_input.resize_(input.sizes(), memory_format);
grad_input.fill_(0);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const auto numThreads = grad_output.numel();
AvgPoolingParams<5> params;
params.dims = dims;
params.pooling_dims = pooling_dims;
params.count_include_pad = count_include_pad;
params.has_divisor_override = divisor_override.has_value();
if (divisor_override.has_value()) {
params.divisor_override = safe_downcast<int32_t, int64_t>(divisor_override.value());
}
for (const auto dim : c10::irange(dims)) {
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
}
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto PSO = lib.getPipelineStateForFunc("avg_pool_backward_" + scalarToMetalTypeString(input));
getMPSProfiler().beginProfileKernel(PSO, op_name, {grad_output});
[computeEncoder setComputePipelineState:PSO];
mtl_setArgs(computeEncoder, grad_input, grad_output, params);
mtl_dispatch1DJob(computeEncoder, PSO, numThreads);
getMPSProfiler().endProfileKernel(PSO);
}
});
}
} // namespace mps
Tensor mps_max_pool2d(const Tensor& input,
@ -1011,68 +896,6 @@ Tensor max_pool3d_with_indices_backward_mps(const Tensor& grad_output,
return grad_input;
}
Tensor& max_unpooling2d_forward_out_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
Tensor& output) {
mps::max_unpool_out_mps_template(self,
indices,
output_size,
/*stride=*/{},
/*padding=*/{},
output,
/*pooling_dims=*/2,
"max_unpool2d");
return output;
}
Tensor max_unpooling2d_forward_mps(const Tensor& self, const Tensor& indices, IntArrayRef output_size) {
auto output = at::empty({0}, self.options());
mps::max_unpool_out_mps_template(self,
indices,
output_size,
/*stride=*/{},
/*padding=*/{},
output,
/*pooling_dims=*/2,
"max_unpool2d");
return output;
}
Tensor& max_unpooling3d_forward_out_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding,
Tensor& output) {
mps::max_unpool_out_mps_template(self,
indices,
output_size,
stride,
padding,
output,
/*pooling_dims=*/3,
"max_unpool3d");
return output;
}
Tensor max_unpooling3d_forward_mps(const Tensor& self,
const Tensor& indices,
IntArrayRef output_size,
IntArrayRef stride,
IntArrayRef padding) {
auto output = at::empty({0}, self.options());
mps::max_unpool_out_mps_template(self,
indices,
output_size,
stride,
padding,
output,
/*pooling_dims=*/3,
"max_unpool3d");
return output;
}
TORCH_IMPL_FUNC(avg_pool2d_out_mps)
(const Tensor& input,
int64_t kH,
@ -1142,26 +965,4 @@ TORCH_IMPL_FUNC(avg_pool3d_out_mps)
"avg_pool3d");
}
TORCH_IMPL_FUNC(avg_pool3d_backward_out_mps)(const Tensor& grad_output,
const Tensor& input,
IntArrayRef kernel_size,
IntArrayRef stride,
IntArrayRef padding,
bool ceil_mode,
bool count_include_pad,
std::optional<int64_t> divisor_override,
const Tensor& grad_input) {
mps::avg_pool_backward_out_mps_template(grad_input,
input,
grad_output,
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
/*pooling_dims=*/3,
"avg_pool3d_backward");
}
} // namespace at::native

View File

@ -719,7 +719,6 @@
dispatch:
CPU, CUDA: all_out
MPS: all_out_mps
MTIA: all_out_mtia
- func: all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
@ -809,7 +808,6 @@
CPU, Meta: arange_out
CUDA: arange_cuda_out
MPS: arange_mps_out
MTIA: arange_mtia_out
cpp_no_default_args: ['step']
# This function is a temporary hack to allow tracing of arange like constructs with dynamic
@ -1891,10 +1889,7 @@
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)
dispatch:
CUDA: cudnn_batch_norm
- func: cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))
dispatch:
CUDA: cudnn_batch_norm_out
autogen: cudnn_batch_norm.out
# NB: You can only use this if you used cudnn_batch_norm training=True
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)
@ -4187,13 +4182,11 @@
dispatch:
CPU: _int_mm_cpu
CUDA: _int_mm_cuda
XPU: _int_mm_xpu
- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: _int_mm_out_cpu
CUDA: _int_mm_out_cuda
XPU: _int_mm_out_xpu
- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
dispatch:
@ -4230,7 +4223,6 @@
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
dispatch:
CPU: _weight_int8pack_mm_cpu
CUDA: _weight_int8pack_mm_cuda
MPS: _weight_int8pack_mm_mps
- func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor
@ -7132,21 +7124,18 @@
dispatch:
CPU: _scaled_mm_cpu
CUDA: _scaled_mm_cuda
tags: needs_exact_strides
- func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
CPU: _scaled_mm_out_cpu
CUDA: _scaled_mm_out_cuda
tags: needs_exact_strides
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda
tags: needs_exact_strides
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
variants: function
@ -10498,7 +10487,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_scalar_kernel_slow_
CUDA: foreach_tensor_add_scalar_kernel_cuda_
MTIA: foreach_tensor_add_scalar_kernel_mtia_
autogen: _foreach_add.Scalar_out
- func: _foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]
@ -10507,7 +10495,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow
CUDA: foreach_tensor_add_list_kernel_cuda
MTIA: foreach_tensor_add_list_kernel_mtia
- func: _foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10515,7 +10502,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_list_kernel_slow_
CUDA: foreach_tensor_add_list_kernel_cuda_
MTIA: foreach_tensor_add_list_kernel_mtia_
autogen: _foreach_add.List_out
- func: _foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
@ -10546,7 +10532,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_add_tensor_kernel_slow_
CUDA: foreach_tensor_add_tensor_kernel_cuda_
MTIA: foreach_tensor_add_tensor_kernel_mtia_
autogen: _foreach_add.Tensor_out
- func: _foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
@ -10607,7 +10592,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_scalar_kernel_slow_
CUDA: foreach_tensor_mul_scalar_kernel_cuda_
MTIA: foreach_tensor_mul_scalar_kernel_mtia_
autogen: _foreach_mul.Scalar_out
- func: _foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]
@ -10616,7 +10600,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow
CUDA: foreach_tensor_mul_list_kernel_cuda
MTIA: foreach_tensor_mul_list_kernel_mtia
- func: _foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10624,7 +10607,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_list_kernel_slow_
CUDA: foreach_tensor_mul_list_kernel_cuda_
MTIA: foreach_tensor_mul_list_kernel_mtia_
autogen: _foreach_mul.List_out
- func: _foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]
@ -10648,7 +10630,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow
CUDA: foreach_tensor_mul_tensor_kernel_cuda
MTIA: foreach_tensor_mul_tensor_kernel_mtia
- func: _foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10656,7 +10637,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_mul_tensor_kernel_slow_
CUDA: foreach_tensor_mul_tensor_kernel_cuda_
MTIA: foreach_tensor_mul_tensor_kernel_mtia_
autogen: _foreach_mul.Tensor_out
- func: _foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]
@ -10953,7 +10933,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow
CUDA: foreach_tensor_addcmul_scalar_cuda
MTIA: foreach_tensor_addcmul_scalar_mtia
- func: _foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -10975,7 +10954,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_addcmul_scalar_slow_
CUDA: foreach_tensor_addcmul_scalar_cuda_
MTIA: foreach_tensor_addcmul_scalar_mtia_
autogen: _foreach_addcmul.Scalar_out
- func: _foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()
@ -11000,7 +10978,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_abs_slow
CUDA: foreach_tensor_abs_cuda
MTIA: foreach_tensor_abs_mtia
- func: _foreach_abs_(Tensor(a!)[] self) -> ()
device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices
@ -11008,7 +10985,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_abs_slow_
CUDA: foreach_tensor_abs_cuda_
MTIA: foreach_tensor_abs_mtia_
autogen: _foreach_abs.out
- func: _foreach_acos(Tensor[] self) -> Tensor[]
@ -11343,7 +11319,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_norm_slow
CUDA: foreach_tensor_norm_cuda
MTIA: foreach_tensor_norm_mtia
autogen: _foreach_norm.Scalar_out
- func: _foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]
@ -11516,7 +11491,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_sqrt_slow_
CUDA: foreach_tensor_sqrt_cuda_
MTIA: foreach_tensor_sqrt_mtia_
autogen: _foreach_sqrt.out
- func: _foreach_tan(Tensor[] self) -> Tensor[]
@ -11578,7 +11552,6 @@
dispatch:
CompositeExplicitAutograd: foreach_tensor_copy_list_kernel_slow_
CUDA: foreach_tensor_copy_list_kernel_cuda_
MTIA: foreach_tensor_copy_list_kernel_mtia_
autogen: _foreach_copy.out
- func: _foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out
@ -11586,7 +11559,6 @@
variants: function
dispatch:
CompositeExplicitAutograd: _foreach_copy
MTIA: foreach_tensor_copy_list_kernel_mtia
- func: bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor
dispatch:
@ -12379,7 +12351,6 @@
dispatch:
CPU: avg_pool3d_backward_out_cpu
CUDA: avg_pool3d_backward_out_cuda
MPS: avg_pool3d_backward_out_mps
MkldnnCPU: mkldnn_avg_pool3d_backward_out
- func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor
@ -12505,28 +12476,24 @@
dispatch:
CPU: max_unpooling2d_forward_out_cpu
CUDA: max_unpooling2d_forward_out_cuda
MPS: max_unpooling2d_forward_out_mps
- func: max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling2d_forward_cpu
CUDA: max_unpooling2d_forward_cuda
MPS: max_unpooling2d_forward_mps
- func: max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_out_cpu
CUDA: max_unpooling3d_forward_out_cuda
MPS: max_unpooling3d_forward_out_mps
- func: max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor
python_module: nn
dispatch:
CPU: max_unpooling3d_forward_cpu
CUDA: max_unpooling3d_forward_cuda
MPS: max_unpooling3d_forward_mps
- func: reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn

View File

@ -1,7 +1,7 @@
# generate a list of kernels, but do not actually emit files at the config stage
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
--api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
@ -11,27 +11,7 @@ endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_SPLITKV kernels via Python.")
endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD_APPENDKV kernels via Python.")
endif()
execute_process(
COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
--api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
RESULT_VARIABLE ret
)
@ -39,29 +19,15 @@ if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
endif()
# Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
# Generate the files for both fwd and bwd
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.")
endif()
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
RESULT_VARIABLE ret
)
@ -78,22 +44,6 @@ if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
endif()
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt"
RESULT_VARIABLE ret)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd_splitkv pass")
endif()
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt"
RESULT_VARIABLE ret)
if(ret AND NOT ret EQUAL 0)
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd appendkv pass")
endif()
# Change make_kernel to make_kernel_pt for bwd
execute_process(
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"

View File

@ -21,8 +21,6 @@ while IFS= read -r file; do
if [ -f "$file" ]; then
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
sed -i 's/make_kernel/make_kernel_pt/g' "$file"
sed -i 's/\#include \"fmha_fwd.hpp\"/\#include \"fmha_fwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
sed -i 's/\#include \"fmha_bwd.hpp\"/\#include \"fmha_bwd.hpp\"\n\#include \"launch_kernel_pt.hpp\"/g' "$file"
echo "Updated: $file"
else
echo "Skipping: $file (not found)"

View File

@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
// keep in sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};
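For readers skimming this new header, a minimal sketch of the decode/serialize round-trip described in the comment above (the strings and the main() wrapper are illustrative, not part of the PR):
// illustrative only: exercise decode/serialize from the bias header above
#include <iostream>
#include "bias.hpp"
int main()
{
    bias_info eb = bias_info::decode("e:2");      // elementwise bias, rank_info = 2 -> b*h*s*s
    std::cout << eb << '\n';                      // prints "e[2]"
    bias_info al = bias_info::decode("alibi:1");  // alibi, rank_info = 1 -> slopes in b*h
    std::cout << al << '\n';                      // prints "alibi[1]"
    std::cout << bias_info::decode("n") << '\n';  // no bias, prints "n"
    return 0;
}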

View File

@ -0,0 +1,457 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/ops/fmha.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <mask.hpp>
#include <bias.hpp>
#include <launch_kernel_pt.hpp>
#include <type_traits>
#include <utility>
#include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using GemmDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::half_t;
using OGradDataType = ck_tile::half_t;
using QGradDataType = ck_tile::half_t;
using KGradDataType = ck_tile::half_t;
using VGradDataType = ck_tile::half_t;
using BiasGradDataType = ck_tile::half_t;
};
template <>
struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using GemmDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::bf16_t;
using OGradDataType = ck_tile::bf16_t;
using QGradDataType = ck_tile::bf16_t;
using KGradDataType = ck_tile::bf16_t;
using VGradDataType = ck_tile::bf16_t;
using BiasGradDataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args; some will be passed to the kargs, some will be used to compute grids/blocks
struct fmha_bwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi: with b*h slopes set this to h, with 1*h slopes set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API; it will be generated by the script
struct fmha_bwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
template <int Version = 2>
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

View File

@ -0,0 +1,824 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/ops/fmha.hpp>
#include <bias.hpp>
#include <mask.hpp>
#include <rotary.hpp>
#include <launch_kernel_pt.hpp>
#include <type_traits>
#include <utility>
#include <variant>
struct FmhaFwdFp16
{
};
struct FmhaFwdBf16
{
};
struct FmhaFwdFp8
{
};
struct FmhaFwdBf8
{
};
struct FmhaFwdFp8Fp16
{
};
struct FmhaFwdFp8Bf16
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdBf8>
{
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args; some will be passed to the kargs, some will be used to compute grids/blocks
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi: with b*h slopes set this to h, with 1*h slopes set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
struct fmha_fwd_splitkv_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_acc_ptr;
void* o_acc_ptr;
void* lse_ptr;
void* o_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
// nullptr.
const void* cache_batch_idx;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
//
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k_ptr[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
//
// when is_gappy=true:
// seqlen_k = kargs.seqlen_k_ptr[b]
// seqstart_k_ptr[b] now store local offset of each batch
//
// when is_gappy=false:
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
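// worked example (hypothetical numbers, for illustration only):
//   group mode, batch = 2, seqstart_q_ptr = {0, 3, 7}, seqstart_k_ptr = {0, 10, 25}
//   -> batch b = 1: seqlen_q = 7 - 3 = 4, seqlen_k = 25 - 10 = 15
//   with is_gappy = true and seqlen_k_ptr = {10, 12}:
//   -> batch b = 1: seqlen_k = 12, and seqstart_k_ptr[1] is read as a local offset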
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
ck_tile::index_t num_splits;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi: with b*h slopes set this to h, with 1*h slopes set this to 0
ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
};
struct fmha_fwd_appendkv_args
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
ck_tile::index_t rotary_dim;
bool has_mask;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale_s,
args.scale_p,
args.scale_o,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
if constexpr(FmhaKernel::kIsGroupMode)
{
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_acc_ptr,
args.o_acc_ptr,
args.batch,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o_acc,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
}();
dim3 grids = Kernel::GridSize(
args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(Kernel::kIsGroupMode)
{
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
else
{ // create batch mode kernel arguments
return Kernel::MakeKargs(args.lse_acc_ptr,
args.o_acc_ptr,
args.lse_ptr,
args.o_ptr,
args.batch,
args.seqlen_q,
args.hdim_v,
args.num_splits,
args.scale_o,
args.stride_o_acc,
args.stride_o,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_lse,
args.batch_stride_o,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.has_mask,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kN1_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadDv_>
struct fmha_fwd_splitkv_combine_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API; it will be generated by the script
struct fmha_fwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
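As a rough sketch of how a caller fills this public traits struct before dispatching (all field values are illustrative; the data_type string and stream handling follow the conventions of the get_ck_fmha_fwd_traits glue shown later in this diff):
// illustrative only
fmha_fwd_traits traits{
    /*hdim_q=*/128,
    /*hdim_v=*/128,
    /*data_type=*/"fp16",
    /*is_group_mode=*/false,
    /*is_v_rowmajor=*/true,
    /*mask_type=*/mask_enum::no_mask,
    /*bias_type=*/bias_enum::no_bias,
    /*has_lse=*/true,
    /*has_dropout=*/false,
    /*do_fp8_static_quant=*/false};
// fmha_fwd_args must be fully populated (pointers, strides, seqlens, dropout seed, ...)
// before calling: float elapsed_ms = fmha_fwd(traits, args, stream_config);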
struct fmha_fwd_splitkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
fmha_fwd_splitkv_args,
const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);

View File

@ -0,0 +1,157 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
mask_top_left,
mask_bottom_right,
window_generic,
};
struct mask_info
{
mask_enum type;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::mask_top_left)
os << "t(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "b(" << left << ":" << right << ")";
else
{
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = atoi(v.c_str());
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
{
left_size = window_size / 2;
right_size = window_size - 1 - left_size;
}
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, t == "xt");
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = left_size;
tmp.right = right_size;
}
else
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
}
}
}
else
{
auto set_causal_top_left = [&]() {
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
};
auto set_causal_bottom_right = [&]() {
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
};
if(str == "t")
set_causal_top_left();
else if(str == "b")
set_causal_bottom_right();
else
{
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
set_causal_top_left();
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
set_causal_bottom_right();
}
}
}
return tmp;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);
return os;
}
};
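A minimal sketch of the mask strings this decoder accepts (the seqlen values are illustrative; the print assumes <iostream> is included):
// illustrative only, with seqlen_q = 4 and seqlen_k = 8
mask_info causal  = mask_info::decode("b", 4, 8);      // bottom-right causal: y = 4, x = 8 - 4 + 1 = 5
mask_info window  = mask_info::decode("t:2,1", 4, 8);   // top-left aligned, FA-style left = 2, right = 1
mask_info xwindow = mask_info::decode("xt:5", 4, 8);    // xformer window 5 -> left = 2, right = 5 - 1 - 2 = 2
std::cout << causal << '\n';                            // prints "b(-1:0)"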

View File

@ -22,7 +22,6 @@ fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
dtype,
false, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
mask.type,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
@ -86,7 +85,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
ck_tile::index_t stride_attn_bias = 0;
ck_tile::index_t batch_stride_bias = 0;
ck_tile::index_t nhead_stride_bias = 0;
if (attn_bias_.has_value()) {
auto a_b = attn_bias_.value();
CHECK_DEVICE(a_b);
@ -96,6 +94,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
nhead_stride_bias = a_b.stride(1);
batch_stride_bias = a_b.stride(0);
}
return fmha_fwd_args{q.data_ptr(),
k.data_ptr(),
v.data_ptr(),
@ -117,7 +116,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
stride_v,
@ -141,7 +139,6 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
-1, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};

View File

@ -20,7 +20,6 @@ fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
dtype,
true, // is_group_mode
true, // is_v_rowmajor
false, // has_logits_soft_cap
mask.type,
enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias,
has_lse,
@ -118,7 +117,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
softmax_scale, // scale_s
1, // scale_p
1, // scale_o
0.0f, // logits_soft_cap
stride_q,
stride_k,
stride_v,
@ -142,7 +140,6 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
-1, // min_seqlen_q
p_dropout,
has_dropout_randval,
drop_seed_offset};

View File

@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep in sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}
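A short usage sketch of the two host-side helpers above (the shapes follow directly from the code; the seed and sizes are illustrative):
// illustrative only
auto [cos, sin] = generate_rotary_cos_sin<ck_tile::half_t>(
    /*seqlen=*/128, /*rotary_dim=*/64, /*seed=*/42u);
// cos and sin each have shape {2 * seqlen, rotary_dim / 2} = {256, 32}
auto [cos_part, sin_part] = slice_rotary_cos_sin(cos, sin,
    /*seqlen_offset=*/16, /*seqlen=*/100);
// cos_part / sin_part cover rows [16, 116) of the full tables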

View File

@ -5,12 +5,6 @@ import os
import sys
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
# Note - hf and timm have their own version of this, torchbench does not
# TODO(voz): Someday, consolidate all the files into one runner instead of a shim like this...
def model_names(filename: str) -> set[str]:
@ -23,8 +17,6 @@ def model_names(filename: str) -> set[str]:
if len(line_parts) == 1:
line_parts = line.split(",")
model_name = line_parts[0]
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
names.add(model_name)
return names

View File

@ -9,7 +9,6 @@ import copy
import csv
import dataclasses
import functools
import gc
import importlib
import itertools
import json
@ -2388,7 +2387,6 @@ class BenchmarkRunner:
)
def warmup(fn, model, example_inputs, mode, niters=10):
gc.collect()
peak_mem = 0
start_stats = get_dynamo_stats()
try:
@ -2550,7 +2548,6 @@ class BenchmarkRunner:
return experiment(*self.maybe_cast(model, example_inputs))
def warmup(fn, model, example_inputs, mode, niters=5):
gc.collect()
peak_mem = 0
start_stats = get_dynamo_stats()
try:

View File

@ -106,11 +106,6 @@ finally:
# on A100 GPUs - 40 GB.
BATCH_SIZE_KNOWN_MODELS = {}
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
# TODO(sdym): use batch-size-file parameter of common.main, like torchbench.py
# Get the list of models and their batch sizes
@ -121,8 +116,6 @@ with open(MODELS_FILENAME) as fh:
lines = [line.rstrip() for line in lines]
for line in lines:
model_name, batch_size = line.split(",")
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
batch_size = int(batch_size)
BATCH_SIZE_KNOWN_MODELS[model_name] = batch_size
assert len(BATCH_SIZE_KNOWN_MODELS)

View File

@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,26800000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1009000000,0.1
basic_modules_ListOfLinears_eager,compile_time_instruction_count,969100000,0.1
@ -82,7 +82,7 @@ mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8787000000,0.1
basic_NestedModule_eager,compile_time_instruction_count,8348000000,0.1

View File

@ -39,20 +39,13 @@ finally:
from timm.models import create_model
TIMM_MODELS = {}
# Run only this selected group of models, leave this empty to run everything
TORCHBENCH_ONLY_MODELS = [
m.strip() for m in os.getenv("TORCHBENCH_ONLY_MODELS", "").split(",") if m.strip()
]
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")
with open(filename) as fh:
lines = fh.readlines()
lines = [line.rstrip() for line in lines]
for line in lines:
model_name, batch_size = line.split(" ")
if TORCHBENCH_ONLY_MODELS and model_name not in TORCHBENCH_ONLY_MODELS:
continue
TIMM_MODELS[model_name] = int(batch_size)

View File

@ -599,7 +599,6 @@ libtorch_nativert_sources = [
"torch/nativert/graph/GraphSignature.cpp",
"torch/nativert/graph/Serialization.cpp",
"torch/nativert/graph/TensorMeta.cpp",
"torch/nativert/graph/GraphUtils.cpp",
"torch/nativert/executor/DelegateExecutor.cpp",
"torch/nativert/executor/Placement.cpp",
"torch/nativert/executor/ExecutionPlanner.cpp",

View File

@ -224,7 +224,7 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
// error out if the key is unrecognized.
if (device_config_parser_hook_) {
TORCH_CHECK(
getKeys().find(key) != getKeys().end(),
keys_.find(key) != keys_.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");

View File

@ -220,24 +220,11 @@ class C10_API AcceleratorAllocatorConfig {
return instance().last_allocator_settings_;
}
// Use `Construct On First Use Idiom` to avoid `Static Initialization Order`
// issue.
static std::unordered_set<std::string>& getMutableKeys() {
static std::unordered_set<std::string> keys{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
return keys;
}
// Returns the set of valid keys for the allocator configuration.
// This set is used to validate the presence and correctness of keys in
// device-specific configuration parsers.
static const std::unordered_set<std::string>& getKeys() {
return getMutableKeys();
return keys_;
}
// Registers a device-specific configuration parser hook and its key. This
@ -251,10 +238,9 @@ class C10_API AcceleratorAllocatorConfig {
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
device_config_parser_hook_ = std::move(hook);
auto& mutable_keys = getMutableKeys();
for (auto& key : keys) {
TORCH_CHECK(
mutable_keys.insert(key).second,
keys_.insert(key).second,
"Duplicated key '",
key,
"' found in device-specific configuration parser hook registration");
@ -340,6 +326,17 @@ class C10_API AcceleratorAllocatorConfig {
// their own environment configuration extensions.
inline static std::function<void(const std::string&)>
device_config_parser_hook_{nullptr};
// A set of valid configuration keys, including both common and
// device-specific options. This set is used to validate the presence and
// correctness of keys during parsing.
inline static std::unordered_set<std::string> keys_{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
};
C10_API inline void setAllocatorSettings(const std::string& env) {

View File

@ -1,10 +0,0 @@
#include <c10/core/CachingDeviceAllocator.h>
namespace c10 {
// Ensures proper DLL export of this pure virtual base class on Windows,
// since it's mainly used in other DLLs outside c10.dll.
DeviceAllocator::DeviceAllocator() = default;
DeviceAllocator::~DeviceAllocator() = default;
} // namespace c10

View File

@ -1,7 +1,6 @@
#pragma once
#include <c10/core/Allocator.h>
#include <c10/core/Stream.h>
namespace c10::CachingDeviceAllocator {
@ -60,55 +59,3 @@ struct DeviceStats {
};
} // namespace c10::CachingDeviceAllocator
namespace c10 {
using CaptureId_t = unsigned long long;
// first is set if the instance is created by Graph mode capture_begin.
// second is set if the instance is created by Graph mode graph_pool_handle.
using MempoolId_t = std::pair<CaptureId_t, CaptureId_t>;
struct C10_API DeviceAllocator : public c10::Allocator {
DeviceAllocator();
~DeviceAllocator() override;
// Returns true if the allocator has been properly initialized and is ready
// for use
virtual bool initialized() = 0;
// Releases all cached device memory from the specified memory pool back to
// the system
virtual void emptyCache(MempoolId_t mempool_id = {0, 0}) = 0;
// Associates a memory allocation with a stream to establish dependency
// tracking. Prevents memory reuse until all operations on the specified
// stream complete
virtual void recordStream(const DataPtr& ptr, c10::Stream stream) = 0;
// Retrieves comprehensive memory statistics for the specified device,
// including allocation patterns, usage metrics
virtual CachingDeviceAllocator::DeviceStats getDeviceStats(
c10::DeviceIndex device) = 0;
// Resets cumulative allocation statistics for the specified device to zero
virtual void resetAccumulatedStats(c10::DeviceIndex device) = 0;
// Resets peak memory usage statistics for the specified device
virtual void resetPeakStats(c10::DeviceIndex device) = 0;
};
// This function is used to get the DeviceAllocator for a specific device type
// and keep backward compatibility with c10::GetAllocator.
C10_API inline DeviceAllocator* getDeviceAllocator(const DeviceType& t) {
TORCH_CHECK(
t != DeviceType::CPU,
"getDeviceAllocator is not supported for CPU device type.");
auto* allocator = c10::GetAllocator(t);
auto* device_allocator = dynamic_cast<DeviceAllocator*>(allocator);
TORCH_INTERNAL_ASSERT(
device_allocator, "Allocator for ", t, " is not a DeviceAllocator.");
return device_allocator;
}
} // namespace c10
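As a short usage sketch of the interface above: fetch the per-device allocator through getDeviceAllocator and drive the virtual API it defines. The CUDA device type is only an illustrative choice; any non-CPU backend whose allocator derives from DeviceAllocator would work the same way.

// Sketch: exercising the DeviceAllocator API shown above.
#include <c10/core/CachingDeviceAllocator.h>

void trimAndReport(c10::DeviceIndex device) {
  c10::DeviceAllocator* allocator =
      c10::getDeviceAllocator(c10::DeviceType::CUDA); // CUDA chosen only as an example
  if (!allocator->initialized()) {
    return;
  }
  // Release cached blocks from the default mempool ({0, 0}) back to the system.
  allocator->emptyCache();
  // Snapshot stats, then clear peak markers for the next measurement window.
  auto stats = allocator->getDeviceStats(device);
  allocator->resetPeakStats(device);
  (void)stats;
}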

View File

@ -191,17 +191,11 @@ class C10_API Scalar {
isIntegral() const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag;
}
bool isIntegral(bool includeBool) const {
return Tag::HAS_i == tag || Tag::HAS_si == tag || Tag::HAS_u == tag ||
(includeBool && isBoolean());
}
// See Note [Meaning of HAS_u]
bool isUnsigned() const {
return Tag::HAS_u == tag || (Tag::HAS_i == tag && v.i >= 0);
}
bool isComplex() const {
return Tag::HAS_z == tag;
}
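A small illustration of the predicates above: boolean scalars only count as integral when includeBool is true, because their tag is HAS_b rather than HAS_i, HAS_si, or HAS_u.

// Sketch: predicate behavior for a boolean Scalar.
#include <c10/core/Scalar.h>

void scalarPredicateDemo() {
  c10::Scalar flag(true);
  bool withBool  = flag.isIntegral(/*includeBool=*/true);   // true: bools opted in
  bool strict    = flag.isIntegral(/*includeBool=*/false);  // false: tag is HAS_b
  bool isComplex = flag.isComplex();                        // false
  (void)withBool; (void)strict; (void)isComplex;
}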

View File

@ -19,16 +19,25 @@
#include <array>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <ostream>
#include <type_traits>
#include <unordered_map>
#include <torch/headeronly/core/ScalarType.h>
namespace c10 {
// [dtype Macros note] For the macros below:
// dummy struct for uint1 to uint7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_uint1_7_t {};
// dummy struct for int1 to int7, actual functionality
// of these dtypes will be implemented in python with Tensor subclass
template <unsigned int N>
struct dummy_int1_7_t {};
// For the macros below:
//
// For users: If you want to macro some code for all non-QInt scalar types
// (i.e. types with complete information), you probably want one of the
@ -48,6 +57,56 @@ namespace c10 {
// some old PRs where we added new dtypes (check history of this file) can
// help give you an idea where to start.
// NB: Order matters for this macro; it is relied upon in
// _promoteTypesLookup and the serialization format.
#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(at::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(c10::complex<c10::Half>, ComplexHalf) /* 8 */ \
_(c10::complex<float>, ComplexFloat) /* 9 */ \
_(c10::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */ \
_(c10::quint2x4, QUInt2x4) /* 17 */ \
_(c10::bits1x8, Bits1x8) /* 18 */ \
_(c10::bits2x4, Bits2x4) /* 19 */ \
_(c10::bits4x2, Bits4x2) /* 20 */ \
_(c10::bits8, Bits8) /* 21 */ \
_(c10::bits16, Bits16) /* 22 */ \
_(c10::Float8_e5m2, Float8_e5m2) /* 23 */ \
_(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */ \
_(c10::Float8_e5m2fnuz, Float8_e5m2fnuz) /* 25 */ \
_(c10::Float8_e4m3fnuz, Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */ \
_(c10::dummy_uint1_7_t<1>, UInt1) /* 30 */ \
_(c10::dummy_uint1_7_t<2>, UInt2) /* 31 */ \
_(c10::dummy_uint1_7_t<3>, UInt3) /* 32 */ \
_(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \
_(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \
_(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \
_(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \
_(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \
_(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \
_(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \
_(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \
_(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \
_(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \
_(c10::dummy_int1_7_t<7>, Int7) /* 43 */ \
_(c10::Float8_e8m0fnu, Float8_e8m0fnu) /* 44 */ \
_(c10::Float4_e2m1fn_x2, Float4_e2m1fn_x2) /* 45 */
// If you want to support ComplexHalf for real, add ComplexHalf
// into this macro (and change the name). But beware: convert()
// doesn't work for all the conversions you need...
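To make the X-macro above concrete, here is a hedged sketch of the kind of table a caller can generate from it. The helper name scalarTypeCppName is invented for illustration; only the macro itself comes from this header.

// Sketch: expanding the X-macro into a switch that maps each ScalarType
// back to the C++ type name it is declared with.
#include <c10/core/ScalarType.h>

const char* scalarTypeCppName(c10::ScalarType t) {
  switch (t) {
#define CASE_NAME(cpp_type, name) \
  case c10::ScalarType::name:     \
    return #cpp_type;
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(CASE_NAME)
#undef CASE_NAME
    default:
      return "Undefined";
  }
}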
@ -93,6 +152,17 @@ namespace c10 {
_(at::Float8_e4m3fnuz, Float8_e4m3fnuz) \
_(at::Float8_e8m0fnu, Float8_e8m0fnu)
enum class ScalarType : int8_t {
#define DEFINE_ST_ENUM_VAL_(_1, n) n,
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ST_ENUM_VAL_)
#undef DEFINE_ST_ENUM_VAL_
Undefined,
NumOptions
};
constexpr uint16_t NumScalarTypes =
static_cast<uint16_t>(ScalarType::NumOptions);
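A brief sketch of what NumScalarTypes enables: an index-bounded sweep over every declared dtype. The toString helper is assumed to be available alongside this enum; it is not shown in this hunk.

// Sketch: iterate every ScalarType by index, using NumScalarTypes as the bound.
#include <cstdint>
#include <iostream>

void listScalarTypes() {
  for (uint16_t i = 0; i < c10::NumScalarTypes; ++i) {
    auto t = static_cast<c10::ScalarType>(i);
    // toString(ScalarType) is assumed available from this header family.
    std::cout << static_cast<int>(i) << ": " << c10::toString(t) << '\n';
  }
}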
namespace impl {
// These are used to map ScalarTypes to C++ types.

Some files were not shown because too many files have changed in this diff.