Compare commits


1 Commit

SHA1: a5e8b0ad38
Message: Trying to reduce flash-deps
ghstack-source-id: 8ba7b23dfde594e126977930e54395405573a598
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144120
Date: 2025-01-09 16:40:01 -08:00
564 changed files with 2821 additions and 11656 deletions

View File

@ -3,9 +3,6 @@ set -eux -o pipefail
GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
# cuda arm build for Grace Hopper solely
export TORCH_CUDA_ARCH_LIST="9.0"
SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh

View File

@ -0,0 +1,5 @@
0.8b
manylinux_2_28
rocm6.2
6f8cbcac8a92775291bb1ba8f514d4beb350baf4
e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8

View File

@ -268,7 +268,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.2.4
ROCM_VERSION=6.1
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
@ -279,7 +279,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.3
ROCM_VERSION=6.2.4
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
@ -497,7 +497,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a}" \
--build-arg "IMAGE_NAME=${IMAGE_NAME}" \
--build-arg "UCX_COMMIT=${UCX_COMMIT}" \
--build-arg "UCC_COMMIT=${UCC_COMMIT}" \

View File

@ -113,6 +113,13 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
# Install AOTriton (Early fail)
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH

View File

@ -1,7 +1,7 @@
set -euo pipefail
readonly version=v24.04
readonly src_host=https://github.com/ARM-software
readonly src_host=https://review.mlplatform.org/ml
readonly src_repo=ComputeLibrary
# Clone ACL

View File

@ -0,0 +1,23 @@
#!/bin/bash
set -ex
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
TARBALL='aotriton.tar.gz'
# This read command always returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_INSTALL_PREFIX="$1"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz"
cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects
curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}"
ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1)
if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then
echo -n "Error: The SHA256 of downloaded tarball is ${ACTUAL_SHA256},"
echo " which does not match the expected value ${SHA256}."
exit 1
fi
tar xf "${TARBALL}" && rm -rf "${TARBALL}"
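The multi-variable read above relies on a bash quirk worth spelling out: with a delimiter that never occurs in the file, read fills all the listed variables from whitespace-separated fields but still exits non-zero at end of input, which is why the script tacks on || true under set -e. A minimal sketch of the same idiom, with an illustrative file path and made-up field values (not part of this PR):

#!/bin/bash
set -euo pipefail
# Hypothetical stand-in for aotriton_version.txt: one field per line.
cat > /tmp/example_version.txt <<'EOF'
0.8b
manylinux_2_28
rocm6.2
deadbeef
cafef00d
EOF
# A NUL delimiter never appears in a text file, so read slurps every field
# and returns 1 at EOF; '|| true' keeps 'set -e' from aborting the script.
read -d '' VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < /tmp/example_version.txt || true
echo "version=${VER} rocm=${ROCMBASE} sha256=${SHA256}"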

View File

@ -62,22 +62,6 @@ install_ubuntu() {
sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
done
# ROCm 6.3 had a regression where initializing static code objects had significant overhead
if [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then
# clr build needs CppHeaderParser but can only find it using conda's python
/opt/conda/bin/python -m pip install CppHeaderParser
git clone https://github.com/ROCm/HIP -b rocm-6.3.x
HIP_COMMON_DIR=$(readlink -f HIP)
git clone https://github.com/jeffdaily/clr -b release/rocm-rel-6.3-statco-hotfix
mkdir -p clr/build
pushd clr/build
cmake .. -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR
make -j
cp hipamd/lib/libamdhip64.so.6.3.* /opt/rocm/lib/libamdhip64.so.6.3.*
popd
rm -rf HIP clr
fi
# Cleanup
apt-get autoclean && apt-get clean
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*

View File

@ -92,6 +92,13 @@ RUN apt-get update -y && \
RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh
# Install AOTriton
COPY ./common/common_utils.sh common_utils.sh
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm && rm install_aotriton.sh aotriton_version.txt
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
FROM ${BASE_TARGET} as final
COPY --from=openssl /opt/openssl /opt/openssl
# Install patchelf

View File

@ -198,3 +198,10 @@ ADD ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh
ADD ./common/install_miopen.sh install_miopen.sh
RUN bash ./install_miopen.sh ${ROCM_VERSION} && rm install_miopen.sh
# Install AOTriton
COPY ./common/common_utils.sh common_utils.sh
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm && rm install_aotriton.sh aotriton_version.txt
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton

View File

@ -304,7 +304,7 @@ pytest-cpp==2.3.0
#Pinned versions: 2.3.0
#test that import:
z3-solver==4.12.6.0
z3-solver==4.12.2.0
#Description: The Z3 Theorem Prover Project
#Pinned versions:
#test that import:

View File

@ -107,6 +107,13 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
# Install AOTriton
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton
# This is needed by sccache
COPY ./common/install_openssl.sh install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl

View File

@ -53,10 +53,22 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
case ${CUDA_VERSION} in
12.6)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
fi
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.4)
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
fi
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.1)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
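One way to sanity-check which architectures a finished wheel actually targets is to ask the installed torch build directly. A hedged sketch, assuming a CUDA-enabled wheel is installed in the current Python environment:

# Print the SM architectures the installed torch binary was compiled for
# (reflects the TORCH_CUDA_ARCH_LIST choices above) plus the CUDA version.
python -c "import torch; print(torch.cuda.get_arch_list()); print(torch.version.cuda)"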

View File

@ -186,6 +186,15 @@ do
OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array
done
# FIXME: Temporary until https://github.com/pytorch/pytorch/pull/137443 lands
# Install AOTriton
if [ -e ${PYTORCH_ROOT}/.ci/docker/aotriton_version.txt ]; then
cp -a ${PYTORCH_ROOT}/.ci/docker/aotriton_version.txt aotriton_version.txt
bash ${PYTORCH_ROOT}/.ci/docker/common/install_aotriton.sh ${ROCM_HOME} && rm aotriton_version.txt
export AOTRITON_INSTALLED_PREFIX=${ROCM_HOME}/aotriton
ROCM_SO_FILES+=("libaotriton_v2.so")
fi
# rocBLAS library files
ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library
@ -257,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC))
DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/})
DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/})
# PyTorch 2.6+ (AOTriton 0.8b+)
# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly
if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n')
if [[ -z ${LIBAOTRITON_DIR} ]]; then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n')
fi
AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n'))
AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images"
AKS_DST="lib/aotriton.images"
DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/})
DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/})
fi
echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"

View File

@ -228,7 +228,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
export CMAKE_BUILD_TYPE=RelWithAssert
fi
# Do not change workspace permissions for ROCm and s390x CI jobs
# Do not change workspace permissions for ROCm CI jobs
# as it can leave workspace with bad permissions for cancelled jobs
if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then
# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)

View File

@ -12,9 +12,9 @@ export TERM=vt100
# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"
# Do not change workspace permissions for ROCm and s390x CI jobs
# Do not change workspace permissions for ROCm CI jobs
# as it can leave workspace with bad permissions for cancelled jobs
if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then
if [[ "$BUILD_ENVIRONMENT" != *rocm* && -d /var/lib/jenkins/workspace ]]; then
# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)
WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace")
cleanup_workspace() {
@ -86,13 +86,6 @@ if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then
export VALGRIND=OFF
fi
if [[ "$BUILD_ENVIRONMENT" == *s390x* ]]; then
# There are additional warnings on s390x, maybe due to newer gcc.
# Skip this check for now
export VALGRIND=OFF
fi
if [[ "${PYTORCH_TEST_RERUN_DISABLED_TESTS}" == "1" ]] || [[ "${CONTINUE_THROUGH_ERROR}" == "1" ]]; then
# When rerunning disabled tests, do not generate core dumps as it could consume
# the runner disk space when crashed tests are run multiple times. Running out
@ -541,7 +534,7 @@ test_perf_for_dashboard() {
--dynamic-batch-only "$@" \
--output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_${mode}_${device}_${target}.csv"
fi
if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]]; then
if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]] && [[ "$mode" == "inference" ]]; then
TORCHINDUCTOR_CPP_WRAPPER=1 $TASKSET python "benchmarks/dynamo/$suite.py" \
"${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \
--output "$TEST_REPORTS_DIR/${backend}_cpp_wrapper_${suite}_${dtype}_${mode}_${device}_${target}.csv"
@ -917,20 +910,10 @@ test_libtorch_api() {
else
# Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"
# On s390x, pytorch is built without llvm.
# Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
# test fails with errors like:
# JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
# unknown file: Failure
# C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
# quantization is not fully supported on s390x yet
if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* && "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* ]]; then
# NB: This test is not under TORCH_BIN_DIR but under BUILD_BIN_DIR
export CPP_TESTS_DIR="${BUILD_BIN_DIR}"
python test/run_test.py --cpp --verbose -i cpp/static_runtime_test

View File

@ -1,9 +1,8 @@
---
# NOTE there must be no spaces before the '-', so put the comma last.
# The check bugprone-unchecked-optional-access is also turned on.
# Note that it can cause clang-tidy to hang randomly. The tracking issue
# The check bugprone-unchecked-optional-access is also turned off atm
# because it causes clang-tidy to hang randomly. The tracking issue
# can be found at https://github.com/llvm/llvm-project/issues/69369.
# When that happens, we can disable it on the problematic code by NOLINT.
InheritParentConfig: true
Checks: '
bugprone-*,
@ -13,6 +12,7 @@ bugprone-*,
-bugprone-lambda-function-name,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
-bugprone-unchecked-optional-access,
clang-analyzer-core.*,
clang-analyzer-cplusplus.*,
clang-analyzer-nullability.*,

View File

@ -5,7 +5,7 @@ body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+sort%3Acreated-desc+). Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+sort%3Acreated-desc+).
- type: textarea
attributes:
label: 🐛 Describe the bug

View File

@ -2,10 +2,6 @@ name: 📚 Documentation
description: Report an issue related to https://pytorch.org/docs/stable/index.html
body:
- type: markdown
attributes:
value: >
#### Note: Please report your documentation issue in English to ensure it can be understood and addressed by the development team.
- type: textarea
attributes:
label: 📚 The doc issue

View File

@ -2,10 +2,6 @@ name: 🚀 Feature request
description: Submit a proposal/request for a new PyTorch feature
body:
- type: markdown
attributes:
value: >
#### Note: Please write your feature request in English to ensure it can be understood and addressed by the development team.
- type: textarea
attributes:
label: 🚀 The feature, motivation and pitch

View File

@ -3,10 +3,6 @@ description: Create a report to help us reproduce and fix the bug
labels: ["oncall: pt2"]
body:
- type: markdown
attributes:
value: >
#### Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.
- type: markdown
attributes:
value: >

View File

@ -17,10 +17,6 @@ runs:
set -ex
diskspace_cutoff=${{ inputs.diskspace-cutoff }}
docker_root_dir=$(docker info -f '{{.DockerRootDir}}')
if [ ! -d "$docker_root_dir" ]; then
echo "Docker root directory ($docker_root_dir) does not exist. Skipping disk space check."
exit 0
fi
diskspace=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
msg="Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified"
if [[ "$diskspace" -ge "$diskspace_cutoff" ]] ; then

View File

@ -5,6 +5,20 @@ description: Set up ROCm host for CI
runs:
using: composite
steps:
- name: Set DOCKER_HOST
shell: bash
run: echo "DOCKER_HOST=unix:///run/user/$(id -u)/docker.sock" >> "${GITHUB_ENV}"
- name: Remove leftover Docker config file
shell: bash
continue-on-error: true
run: |
set -ex
cat ~/.docker/config.json || true
# https://stackoverflow.com/questions/64455468/error-when-logging-into-ecr-with-docker-login-error-saving-credentials-not
rm -f ~/.docker/config.json
- name: Stop all running docker containers
if: always()
shell: bash
@ -97,10 +111,8 @@ runs:
shell: bash
run: |
# All GPUs are visible to the runner; visibility, if needed, will be set by run_test.py.
# Add render group for container creation.
render_gid=`cat /etc/group | grep render | cut -d: -f3`
# The --group-add daemon and --group-add bin are needed in the Ubuntu 24.04 and Almalinux OSs respectively.
# This is due to the device files (/dev/kfd & /dev/dri) being owned by video group on bare metal.
# This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries.
# The group name corresponding to group ID 1 can change depending on the OS, so both are necessary.
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device /dev/dri --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}"
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon --group-add bin" >> "${GITHUB_ENV}"

View File

@ -1,12 +1,12 @@
# Self-Hosted IBM Z Github Actions Runner.
# Temporary image: amd64 dependencies.
FROM --platform=linux/amd64 docker.io/ubuntu:24.04 as ld-prefix
FROM docker.io/amd64/ubuntu:23.10 as ld-prefix
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get -y install ca-certificates libicu74 libssl3
RUN apt-get update && apt-get -y install ca-certificates libicu72 libssl3
# Main image.
FROM --platform=linux/s390x docker.io/ubuntu:24.04
FROM docker.io/s390x/ubuntu:23.10
# Packages for pytorch building and testing.
ENV DEBIAN_FRONTEND=noninteractive

View File

@ -219,10 +219,6 @@ jobs:
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
JENKINS_USER=
USED_IMAGE="${DOCKER_IMAGE_S390X}"
# ensure that docker container cleanly exits in 12 hours
# if for some reason cleanup action doesn't stop container
# when job is cancelled
DOCKER_SHELL_CMD="sleep 12h"
# since some steps are skipped on s390x, if they are necessary, run them here
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
@ -230,7 +226,6 @@ jobs:
else
JENKINS_USER="--user jenkins"
USED_IMAGE="${DOCKER_IMAGE}"
DOCKER_SHELL_CMD=
fi
# Leaving 1GB for the runner and other things
@ -240,7 +235,7 @@ jobs:
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))
# detached container should get cleaned up by teardown_ec2_linux
# Used for JENKINS_USER and DOCKER_SHELL_CMD, which can be empty
# Used for JENKINS_USER, which can be empty
# shellcheck disable=SC2086
container_name=$(docker run \
-e BUILD_ENVIRONMENT \
@ -271,8 +266,7 @@ jobs:
${JENKINS_USER} \
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
-w /var/lib/jenkins/workspace \
"${USED_IMAGE}" \
${DOCKER_SHELL_CMD}
"${USED_IMAGE}"
)
docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh'
@ -338,5 +332,6 @@ jobs:
shell: bash
run: |
# on s390x stop the container for clean worker stop
docker stop -a || true
docker kill -a || true
# ignore expansion of "docker ps -q" since it could be empty
# shellcheck disable=SC2046
docker stop $(docker ps -q) || true

View File

@ -81,7 +81,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ !contains(matrix.runner, 'gcp.a100') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ !contains(matrix.runner, 'gcp.a100') }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |
@ -95,10 +95,9 @@ jobs:
- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
- name: configure aws credentials
if : ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if : ${{ inputs.aws-role-to-assume != '' }}
uses: aws-actions/configure-aws-credentials@v3
with:
role-to-assume: ${{ inputs.aws-role-to-assume }}
@ -108,13 +107,11 @@ jobs:
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
with:
docker-image-name: ${{ inputs.docker-image }}
- name: Use following to pull public copy of the image
id: print-ghcr-mirror
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
env:
ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
shell: bash
@ -124,7 +121,6 @@ jobs:
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}
@ -170,7 +166,6 @@ jobs:
with:
name: ${{ inputs.build-environment }}
s3-bucket: ${{ inputs.s3-bucket }}
use-gha: ${{ inputs.use-gha }}
- name: Download TD artifacts
continue-on-error: true
@ -267,21 +262,9 @@ jobs:
# comes from https://github.com/pytorch/test-infra/pull/6058
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
SHM_OPTS=
JENKINS_USER=
# since some steps are skipped on s390x, if they are necessary, run them here
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
else
SHM_OPTS="--shm-size=${SHM_SIZE}"
JENKINS_USER="--user jenkins"
fi
# detached container should get cleaned up by teardown_ec2_linux
# TODO: Stop building test binaries as part of the build phase
# Used for GPU_FLAG, SHM_OPTS and JENKINS_USER since that doesn't play nice
# Used for GPU_FLAG since that doesn't play nice
# shellcheck disable=SC2086,SC2090
container_name=$(docker run \
${GPU_FLAG:-} \
@ -333,11 +316,11 @@ jobs:
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \
--ipc=host \
${SHM_OPTS} \
--shm-size="${SHM_SIZE}" \
--tty \
--detach \
--name="${container_name}" \
${JENKINS_USER} \
--user jenkins \
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
-w /var/lib/jenkins/workspace \
"${DOCKER_IMAGE}"
@ -345,11 +328,6 @@ jobs:
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
docker exec -t "${container_name}" sh -c "python3 -m pip install -r .ci/docker/requirements-ci.txt"
fi
docker exec -t "${container_name}" sh -c "python3 -m pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"
- name: Upload pytest cache if tests failed
@ -489,12 +467,3 @@ jobs:
echo "NVIDIA driver detects $GPU_COUNT GPUs. The runner has a broken GPU, shutting it down..."
.github/scripts/stop_runner_service.sh
fi
- name: Cleanup docker
if: always() && inputs.build-environment == 'linux-s390x-binary-manywheel'
shell: bash
run: |
# on s390x stop the container for clean worker stop
# ignore expansion of "docker ps -q" since it could be empty
# shellcheck disable=SC2046
docker stop $(docker ps -q) || true

View File

@ -152,7 +152,6 @@ jobs:
set -e
${CONDA_RUN} python3 test/run_test.py --mps --verbose
MTL_CAPTURE_ENABLED=1 ${CONDA_RUN} python3 test/test_mps.py --verbose -k test_metal_capture
- name: Print remaining test logs
shell: bash

View File

@ -28,11 +28,6 @@ on:
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false
@ -43,6 +38,11 @@ on:
required: false
type: boolean
default: false
freeze_autotune_cudagraphs:
description: Run inductor_cudagraphs with freezing and max autotune for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false
@ -111,7 +111,7 @@ jobs:
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
@ -127,7 +127,7 @@ jobs:
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
@ -143,7 +143,7 @@ jobs:
if: github.event_name == 'workflow_dispatch'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-false-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha

View File

@ -29,13 +29,13 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
linux-focal-rocm6_3-py3_10-inductor-build:
name: rocm6.3-py3.10-inductor
linux-focal-rocm6_2-py3_10-inductor-build:
name: rocm6.2-py3.10-inductor
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
@ -44,15 +44,15 @@ jobs:
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-inductor-test:
linux-focal-rocm6_2-py3_10-inductor-test:
permissions:
id-token: write
contents: read
name: rocm6.3-py3.10-inductor
name: rocm6.2-py3.10-inductor
uses: ./.github/workflows/_rocm-test.yml
needs: linux-focal-rocm6_3-py3_10-inductor-build
needs: linux-focal-rocm6_2-py3_10-inductor-build
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-inductor-build.outputs.test-matrix }}
secrets: inherit

View File

@ -139,13 +139,13 @@ jobs:
test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.test-matrix }}
secrets: inherit
linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
@ -155,19 +155,19 @@ jobs:
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build:

View File

@ -411,15 +411,15 @@ jobs:
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-build:
linux-focal-rocm6_2-py3_10-build:
# don't run build twice on main
if: github.event_name == 'pull_request'
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |

View File

@ -26,36 +26,36 @@ jobs:
id-token: write
contents: read
linux-focal-rocm6_3-py3_10-build:
linux-focal-rocm6_2-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 3, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 4, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 5, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 6, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit

View File

@ -1,77 +0,0 @@
name: s390x-periodic
on:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/s390/*
branches:
- release/*
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true
permissions: read-all
jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read
target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read
linux-manylinux-2_28-py3-cpu-s390x-build:
if: github.repository_owner == 'pytorch'
name: linux-manylinux-2_28-py3-cpu-s390x
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-s390x-binary-manywheel
docker-image-name: pytorch/manylinuxs390x-builder:cpu-s390x-main
runner: linux.s390x
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 2, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 3, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 4, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 5, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 6, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 7, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 8, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 9, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 10, num_shards: 10, runner: "linux.s390x" },
]}
secrets: inherit
linux-manylinux-2_28-py3-cpu-s390x-test:
permissions:
id-token: write
contents: read
name: linux-manylinux-2_28-py3-cpu-s390x
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-manylinux-2_28-py3-cpu-s390x-build
- target-determination
with:
build-environment: linux-s390x-binary-manywheel
docker-image: pytorch/manylinuxs390x-builder:cpu-s390x-main
test-matrix: ${{ needs.linux-manylinux-2_28-py3-cpu-s390x-build.outputs.test-matrix }}
timeout-minutes: 480
use-gha: "yes"
secrets: inherit

View File

@ -103,13 +103,13 @@ jobs:
test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }}
secrets: inherit
linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [
@ -118,19 +118,19 @@ jobs:
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit
linux-jammy-py3_10-clang15-asan-build:

View File

@ -164,36 +164,36 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit
linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.4' || 'linux.rocm.gpu.4' }}" },
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.4" },
]}
secrets: inherit
linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
secrets: inherit

.gitmodules (vendored), 3 changed lines
View File

@ -134,3 +134,6 @@
[submodule "third_party/kleidiai"]
path = third_party/kleidiai
url = https://git.gitlab.arm.com/kleidi/kleidiai.git
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/drisspg/flash-attention.git

View File

@ -38,29 +38,26 @@ aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/s
generated_cpu_cpp = [
"aten/src/ATen/RegisterBackendSelect.cpp",
"aten/src/ATen/RegisterCPU_0.cpp",
"aten/src/ATen/RegisterCPU_1.cpp",
"aten/src/ATen/RegisterCPU_2.cpp",
"aten/src/ATen/RegisterCPU_3.cpp",
"aten/src/ATen/RegisterCPU.cpp",
"aten/src/ATen/RegisterFunctionalization_0.cpp",
"aten/src/ATen/RegisterFunctionalization_1.cpp",
"aten/src/ATen/RegisterFunctionalization_2.cpp",
"aten/src/ATen/RegisterFunctionalization_3.cpp",
# "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
"aten/src/ATen/RegisterMkldnnCPU_0.cpp",
"aten/src/ATen/RegisterNestedTensorCPU_0.cpp",
"aten/src/ATen/RegisterQuantizedCPU_0.cpp",
"aten/src/ATen/RegisterSparseCPU_0.cpp",
"aten/src/ATen/RegisterSparseCsrCPU_0.cpp",
"aten/src/ATen/RegisterZeroTensor_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
"aten/src/ATen/RegisterMeta_0.cpp",
"aten/src/ATen/RegisterSparseMeta_0.cpp",
"aten/src/ATen/RegisterQuantizedMeta_0.cpp",
"aten/src/ATen/RegisterNestedTensorMeta_0.cpp",
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterNestedTensorCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterSparseMeta.cpp",
"aten/src/ATen/RegisterQuantizedMeta.cpp",
"aten/src/ATen/RegisterNestedTensorMeta.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CPUFunctions_inl.h",
@ -100,11 +97,11 @@ generated_cpu_cpp = [
generated_cuda_cpp = [
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/CUDAFunctions_inl.h",
"aten/src/ATen/RegisterCUDA_0.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA_0.cpp",
"aten/src/ATen/RegisterQuantizedCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA_0.cpp",
"aten/src/ATen/RegisterCUDA.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA.cpp",
"aten/src/ATen/RegisterQuantizedCUDA.cpp",
"aten/src/ATen/RegisterSparseCUDA.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]
generate_aten(

View File

@ -876,6 +876,10 @@ cmake_dependent_option(
# feature by default. We don't currently document this feature because we don't
# suspect users building from source will need this
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
add_definitions(-DFLASHATTENTION_DISABLE_SOFTCAP)
add_definitions(-DFLASH_NAMESPACE=pytorch_flash)
# See https://github.com/pytorch/pytorch/issues/121558 for details
add_definitions(-DUNFUSE_FMA)
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't

View File

@ -164,9 +164,18 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")
# flash_attention sources
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# list(APPEND flash_attention_cuda_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
# list(APPEND flash_attention_cuda_cpp ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp)
# Flash attention C++ sources
file(GLOB flash_attention_cuda_cpp
"${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
"native/transformers/cuda/flash_attn/flash_api.cpp"
)
# file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
# file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
# file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
@ -481,9 +490,6 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE AND NOT (MSVC AND CMAKE_SYSTEM_PRO
set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
if(WIN32)
set(SLEEF_BUILD_WITH_LIBM OFF CACHE BOOL "Don't build sleef with libm for Windows." FORCE)
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)

View File

@ -394,6 +394,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
rocm_fa_preferred_backend = b;
}
bool Context::allowFP16ReductionCuBLAS() const {
return allow_fp16_reduction_cublas;
}
@ -410,14 +411,6 @@ void Context::setAllowBF16ReductionCuBLAS(bool b) {
allow_bf16_reduction_cublas = b;
}
bool Context::allowFP16AccumulationCuBLAS() const {
return allow_fp16_accumulation_cublas;
}
void Context::setAllowFP16AccumulationCuBLAS(bool b) {
allow_fp16_accumulation_cublas = b;
}
bool Context::hasMKL() {
#if AT_MKL_ENABLED()

View File

@ -337,8 +337,6 @@ class TORCH_API Context {
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(bool);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();
@ -420,7 +418,6 @@ class TORCH_API Context {
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool allow_fp16_accumulation_cublas = false;
bool enabled_mkldnn = true;
bool enabled_nnpack = true;
at::LinalgBackend linalg_preferred_backend =

View File

@ -54,7 +54,6 @@ bool isAccelerator(c10::DeviceType device_type) {
}
}
// NOLINTBEGIN(bugprone-unchecked-optional-access)
c10::DeviceIndex deviceCount() {
const auto device_type = getAccelerator(false);
if (!device_type.has_value()) {
@ -100,6 +99,5 @@ void synchronizeDevice(c10::DeviceIndex device_index) {
// impl.synchronizeDevice can be safely called from any device
impl.synchronizeDevice(device_index);
}
// NOLINTEND(bugprone-unchecked-optional-access)
} // namespace at::accelerator

View File

@ -86,14 +86,14 @@ TaskThreadPoolBase& _get_intraop_pool() {
#endif // C10_MOBILE
// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: task_id.
static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range) {
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
for (const auto i : c10::irange(1, range)) {
_get_intraop_pool().run([fn, i]() { fn(i); });
_get_intraop_pool().run([fn, i]() { fn((int)i, i); });
}
// Run the first task on the current thread directly.
fn(0);
fn(0, 0);
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
@ -102,7 +102,7 @@ static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range)
// PThreadPool::run() is blocking. A std::function [const] reference to
// this lambda cannot go out of scope before PThreadPool::run() returns.
[&fn](const size_t task_id) {
fn(task_id);
fn(0 /* unused */, task_id);
}, range);
#endif // C10_MOBILE
}
@ -113,10 +113,6 @@ struct ParallelRegionGuard {
internal::set_thread_num(task_id);
_set_in_parallel_region(true);
}
ParallelRegionGuard(const ParallelRegionGuard&) = delete;
ParallelRegionGuard(ParallelRegionGuard&&) = delete;
ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;
~ParallelRegionGuard() {
_set_in_parallel_region(false);
@ -128,16 +124,16 @@ struct ParallelRegionGuard {
namespace internal {
static std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
int64_t begin, int64_t end, int64_t grain_size) {
if ((end - begin) < grain_size) {
return std::make_tuple(1, std::max((int64_t)0, end - begin));
}
// Choose number of tasks based on grain size and number of threads.
int64_t chunk_size = divup((end - begin), get_num_threads());
size_t chunk_size = divup((end - begin), get_num_threads());
// Make sure each task is at least grain_size size.
chunk_size = std::max(grain_size, chunk_size);
size_t num_tasks = static_cast<size_t>(divup((end - begin), chunk_size));
chunk_size = std::max((size_t)grain_size, chunk_size);
size_t num_tasks = divup((end - begin), chunk_size);
return std::make_tuple(num_tasks, chunk_size);
}
@ -161,12 +157,12 @@ void invoke_parallel(
} state;
auto task = [f, &state, begin, end, chunk_size]
(size_t task_id) {
int64_t local_start = static_cast<int64_t>(begin + task_id * chunk_size);
(int /* unused */, size_t task_id) {
int64_t local_start = begin + task_id * chunk_size;
if (local_start < end) {
int64_t local_end = std::min(end, static_cast<int64_t>(chunk_size + local_start));
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
try {
ParallelRegionGuard guard(static_cast<int>(task_id));
ParallelRegionGuard guard(task_id);
f(local_start, local_end);
} catch (...) {
if (!state.err_flag.test_and_set()) {

View File

@ -656,11 +656,10 @@ struct TORCH_API TensorType : public SharedType {
const auto& shape = sizes();
for (size_t i = 0; i < shape.size(); i++) {
auto const &s = shape[i];
if (!s.has_value()) {
if (!shape[i].has_value()) {
return std::optional<size_t>{};
}
prod *= s.value();
prod *= shape[i].value();
}
return prod;
}
@ -728,11 +727,10 @@ struct TORCH_API TensorType : public SharedType {
TensorTypePtr contiguous() const {
auto cloned = clone();
auto concrete_sizes = sizes().concrete_sizes();
TORCH_INTERNAL_ASSERT(concrete_sizes.has_value());
TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
auto strides = computeStrideProps(
*concrete_sizes,
contiguousStridesOf(*concrete_sizes));
*sizes().concrete_sizes(),
contiguousStridesOf(*sizes().concrete_sizes()));
cloned->strides_ = strides;
return cloned;
}
@ -1518,8 +1516,8 @@ struct TORCH_API FunctionType : public NamedType {
FunctionType(torch::jit::Function* function);
std::string annotation_str_impl(
[[maybe_unused]] const TypePrinter& printer = nullptr) const override {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
return name()->qualifiedName();
const auto& n = name().value();
return n.qualifiedName();
}
torch::jit::Function* function_;
};
@ -2135,7 +2133,6 @@ struct MatchTypeReturn {
return !reason_.has_value();
}
const std::string& reason() const {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
return reason_.value();
}
@ -2184,7 +2181,6 @@ struct TORCH_API InterfaceType : public NamedType {
}
std::string str() const override {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
return std::string("InterfaceType<") + name()->name() + ">";
}
@ -2212,7 +2208,6 @@ struct TORCH_API InterfaceType : public NamedType {
std::string annotation_str_impl(
[[maybe_unused]] const TypePrinter& printer = nullptr) const override {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
return name()->qualifiedName();
}

View File

@ -904,7 +904,6 @@ bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
std::string TupleType::str() const {
std::stringstream ss;
if (schema_ && name().has_value()) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
ss << name()->qualifiedName();
} else {
ss << "(";

View File

@ -526,7 +526,12 @@ Vectorized<c10::BFloat16> inline fmadd(
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
const auto [c_float_low, c_float_high] = convert_bfloat16_float(c);
return convert_float_bfloat16(
fmadd(a_float_low, b_float_low, c_float_low),
fmadd(a_float_high, b_float_high, c_float_high));
}
template <>
@ -535,7 +540,12 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
// See NOTE [BF16 FMA] above.
return a * b - c;
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
const auto [c_float_low, c_float_high] = convert_bfloat16_float(c);
return convert_float_bfloat16(
fmsub(a_float_low, b_float_low, c_float_low),
fmsub(a_float_high, b_float_high, c_float_high));
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)

View File

@ -572,7 +572,12 @@ Vectorized<c10::Half> inline fmadd(
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
return Vectorized<c10::Half>(vfmaq_f16(c, a, b));
#else
return a * b + c;
const auto [a_float_low, a_float_high] = convert_half_float(a);
const auto [b_float_low, b_float_high] = convert_half_float(b);
const auto [c_float_low, c_float_high] = convert_half_float(c);
return convert_float_half(
fmadd(a_float_low, b_float_low, c_float_low),
fmadd(a_float_high, b_float_high, c_float_high));
#endif
}
@ -584,7 +589,12 @@ Vectorized<c10::Half> inline fmsub(
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
#else
return a * b - c;
const auto [a_float_low, a_float_high] = convert_half_float(a);
const auto [b_float_low, b_float_high] = convert_half_float(b);
const auto [c_float_low, c_float_high] = convert_half_float(c);
return convert_float_half(
fmsub(a_float_low, b_float_low, c_float_low),
fmsub(a_float_high, b_float_high, c_float_high));
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)

View File

@ -284,7 +284,6 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
// NOLINTNEXTLINE(bugprone-sizeof-expression)
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
}
};
@ -332,12 +331,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
@ -354,16 +347,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
abcType = CUDA_C_32F;
scaleType = CUDA_C_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
computeType = CUBLAS_COMPUTE_16F;
halpha = alpha;
hbeta = beta;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
#endif
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
@ -409,7 +392,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@ -431,12 +414,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
alpha_ptr,
&alpha,
a,
Adesc.descriptor(),
b,
Bdesc.descriptor(),
beta_ptr,
&beta,
c,
Cdesc.descriptor(),
c,
@ -546,11 +529,6 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
BGEMM_CHECK_ARGVALUES(at::Half);
float falpha = alpha;
float fbeta = beta;
at::Half halpha;
at::Half hbeta;
void * alpha_ptr = &falpha;
void * beta_ptr = &fbeta;
auto compute_type = CUDA_R_32F;
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
@ -567,20 +545,13 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
0, flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
halpha = alpha;
hbeta = beta;
compute_type = CUDA_R_16F;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
if (prop->major >= 5){
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
handle, opa, opb, m, n, k,
alpha_ptr, a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, beta_ptr,
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
c, CUDA_R_16F, ldc, stridec,
num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} else {
for (const auto i : c10::irange(num_batches)) {
at::cuda::blas::gemm<at::Half>(
@ -895,13 +866,8 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
cublasOperation_t opb = _cublasOpFromChar(transb);
float falpha = alpha;
float fbeta = beta;
at::Half halpha;
at::Half hbeta;
void * alpha_ptr = &falpha;
void * beta_ptr = &fbeta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
auto compute_type = CUDA_R_32F;
#ifdef USE_ROCM
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
@ -934,18 +900,13 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
flag)));
#else
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
compute_type = CUDA_R_16F;
halpha = alpha;
hbeta = beta;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
if (prop->major >= 5) {
#ifndef USE_ROCM
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
}
#endif
// Disallow fp16 reductions that could lead to unexpected overflow issues.
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
TORCH_CUDABLAS_CHECK(cublasGemmEx(
@ -955,18 +916,18 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
m,
n,
k,
alpha_ptr,
&falpha,
a,
CUDA_R_16F,
lda,
b,
CUDA_R_16F,
ldb,
beta_ptr,
&fbeta,
c,
CUDA_R_16F,
ldc,
compute_type,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
} else {
@ -1268,12 +1229,6 @@ void gemm_and_bias(
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
void * alpha_ptr = &alpha_val;
void * beta_ptr = &beta_val;
#ifndef USE_ROCM
at::Half halpha_val;
at::Half hbeta_val;
#endif
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
@ -1284,17 +1239,6 @@ void gemm_and_bias(
}
abcType = CUDA_R_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
computeType = CUBLAS_COMPUTE_16F;
scaleType = CUDA_R_16F;
halpha_val = alpha_val;
hbeta_val = beta_val;
alpha_ptr = &halpha_val;
beta_ptr = &hbeta_val;
}
#endif
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
@ -1340,7 +1284,7 @@ void gemm_and_bias(
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
#endif
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
cublasLtMatmulHeuristicResult_t heuristicResult = {};
int returnedResult = 0;
@ -1363,12 +1307,12 @@ void gemm_and_bias(
cublasStatus_t cublasStatus = cublasLtMatmul(
ltHandle,
computeDesc.descriptor(),
alpha_ptr,
&alpha_val,
mat1_ptr,
Adesc.descriptor(),
mat2_ptr,
Bdesc.descriptor(),
beta_ptr,
&beta_val,
result_ptr,
Cdesc.descriptor(),
result_ptr,
@ -1522,7 +1466,7 @@ void scaled_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
}
size_t workspaceSize = _getWorkspaceSize();
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
CuBlasLtMatmulPreference preference;
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
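The FP16-accumulation hunks in this file all hinge on one small dispatch rule, sketched below in a self-contained form (assumed logic, not ATen's actual helper): at::Half GEMMs accumulate in FP32 by default (with float alpha/beta), and switch to FP16 accumulation only on sm_70+ and only when the global allowFP16AccumulationCuBLAS switch is on, in which case alpha/beta must also be handed to cuBLAS as half-precision values.

#include <cstdio>

// Sketch of the compute-type decision toggled by the hunks above. When FP16
// accumulation is selected, the scaling factors must be half-precision too,
// which is why the surrounding code keeps both float and at::Half copies of
// alpha/beta and switches the pointers it passes.
enum class HalfGemmCompute { FP32Accumulate, FP16Accumulate };

HalfGemmCompute pick_half_gemm_compute(int sm_major, bool allow_fp16_accumulation) {
  return (sm_major >= 7 && allow_fp16_accumulation)
      ? HalfGemmCompute::FP16Accumulate   // CUBLAS_COMPUTE_16F, half alpha/beta
      : HalfGemmCompute::FP32Accumulate;  // CUBLAS_COMPUTE_32F, float alpha/beta
}

int main() {
  std::printf("%d\n", static_cast<int>(pick_half_gemm_compute(8, true)));  // FP16Accumulate
  std::printf("%d\n", static_cast<int>(pick_half_gemm_compute(6, true)));  // FP32Accumulate
}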

View File

@ -56,6 +56,7 @@ cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
}
}
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const=false) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
IntArrayRef input_strides = input.strides();
@ -120,6 +121,7 @@ CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t ba
CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
}
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
// cuSPARSE doesn't support batched vectors

View File

@ -116,7 +116,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
virtual void recordMemoryHistory(
const std::optional<std::string>& enabled,
std::optional<std::string> enabled,
const std::string& stacks,
size_t max_entries) const {
FAIL_MTIAHOOKS_FUNC(__func__);

View File

@ -162,7 +162,6 @@ grid_sample_backward_helper_in(
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
grid_sample_backward_helper_out(
// NOLINTNEXTLINE(performance-unnecessary-value-param)
std::tuple<Tensor, Tensor> bw_out,
int64_t grad_input_out_bdim,
int64_t grad_grid_out_bdim,
@ -262,7 +261,7 @@ struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
auto out = Func(
std::move(grad_output_),
output_size,
std::move(output_size),
std::move(physical_input_size),
std::forward<T>(extra_args)...);
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)

View File

@ -16,7 +16,6 @@
// registered to FuncTorchVmapMode. This is because we need to interpose on
// random operations even if they're not on a BatchedTensor.
// NOLINTBEGIN(bugprone-unchecked-optional-access)
namespace at::functorch {
template <typename F, F Func, typename... ExtraArgs>
@ -502,4 +501,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
}
} // namespace at::functorch
// NOLINTEND(bugprone-unchecked-optional-access)

View File

@ -11,7 +11,6 @@
#include <utility>
// NOLINTBEGIN(bugprone-unchecked-optional-access)
namespace at::functorch {
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
@ -511,4 +510,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
}
} // namespace at::functorch
// NOLINTEND(bugprone-unchecked-optional-access)

View File

@ -14,7 +14,6 @@
#include <torch/library.h>
// NOLINTBEGIN(bugprone-unchecked-optional-access)
namespace at::functorch {
namespace {
@ -1284,4 +1283,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
}
} // namespace at::functorch
// NOLINTEND(bugprone-unchecked-optional-access)

View File

@ -156,7 +156,6 @@ const Tensor& resize__plumbing(
"resize_: batching rule only supports None or Contiguous MemoryFormat");
auto maybe_layer = maybeCurrentDynamicLayer();
vmap_check_escaped(maybe_layer, "resize__plumbing");
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
int64_t cur_level = maybe_layer->layerId();
if (!isBatchedAtLevel(self, cur_level)) {
c10::impl::ExcludeDispatchKeyGuard guard2(DispatchKey::FuncTorchBatched);

View File

@ -41,7 +41,6 @@ DynamicLayer::DynamicLayer(
}
switch (transform_type) {
case TransformType::Vmap:
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
interpreter_ = Interpreter::Vmap(layerId, std::move(batchSize.value()), randomness.value());
break;
case TransformType::Grad:
@ -51,7 +50,6 @@ DynamicLayer::DynamicLayer(
interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
break;
case TransformType::Functionalize:
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
break;
default:
@ -347,7 +345,9 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
if (!ivalue.isTensor()) {
continue;
}
args[idx] = func(ivalue.toTensor(), flag);
Tensor value = ivalue.toTensor();
Tensor replacement = func(value, flag);
args[idx] = std::move(replacement);
// sanity checks
if (ivalue.toTensor().defined()) {
TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());

View File

@ -118,7 +118,6 @@ static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional<int64_t> dim,
// to `batch_sizes`
VmapPhysicalViewVec
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
auto cur_level = maybeCurrentDynamicLayer().value().layerId();
c10::SymInt bdim_size = -1;

View File

@ -16,10 +16,6 @@
#include <unordered_map>
#include <utility>
#ifndef __OBJC__
typedef void* MTLCaptureManager;
#endif
namespace at::mps {
namespace Profiler {
@ -62,7 +58,24 @@ struct BaseInfo {
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
static std::string buildTensorString(
const Tensor& tensor,
bool includeBufferId = false);
bool includeBufferId = false) {
if (tensor.defined()) {
std::stringstream tensorStr;
auto deviceType = tensor.device().type();
tensorStr << c10::DeviceTypeName(deviceType);
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer =
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
<< buffer.retainCount << ")";
}
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";
}
}
static uint64_t getTime() {
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
}

View File

@ -30,23 +30,6 @@ const std::string BaseInfo::toString(double gpuTime, double schedulingTime) cons
schedulingTime > 0.0 ? fmt::format(", cpu={:.3f} ms", schedulingTime) : "");
}
std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBufferId) {
if (tensor.defined()) {
std::stringstream tensorStr;
auto deviceType = tensor.device().type();
tensorStr << c10::DeviceTypeName(deviceType);
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
}
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";
}
}
const std::string OperationInfo::toString(double gpuTime, double schedulingTime) const {
return fmt::format("aten::{} (id={}{}, run={}{})",
strKey,

View File

@ -15,26 +15,21 @@
#include <Metal/Metal.h>
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
typedef MPSCommandBuffer* MPSCommandBuffer_t;
typedef id<MTLCommandQueue> MTLCommandQueue_t;
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
typedef id<MTLSharedEvent> MTLSharedEvent_t;
typedef id<MTLDevice> MTLDevice_t;
typedef id<MTLBuffer> MTLBuffer_t;
#else
#include <dispatch/dispatch.h>
typedef void* MPSCommandBuffer_t;
typedef void* MPSGraph;
typedef void* MPSGraphExecutionDescriptor;
typedef void* MPSGraphCompilationDescriptor;
typedef void* MTLCommandQueue_t;
typedef void* MTLCommandQueue;
typedef void* MTLCommandBuffer_t;
typedef void* MTLCommandBuffer;
typedef void* MTLComputeCommandEncoder_t;
typedef void* MTLSharedEvent_t;
typedef void* dispatch_queue_t;
typedef void* MTLDevice_t;
typedef void* MTLBuffer_t;
typedef void* MTLCommandBufferHandler;
typedef void* NSDictionary;
#define nil NULL
#define nil NULL;
#endif
namespace at::mps {
@ -60,29 +55,27 @@ class TORCH_API MPSStream {
explicit MPSStream(Stream stream);
~MPSStream();
MTLCommandQueue_t commandQueue() const {
return _commandQueue;
}
};
dispatch_queue_t queue() const {
return _serialQueue;
}
MPSCommandBuffer_t commandBuffer();
MPSCommandBuffer* commandBuffer();
MTLComputeCommandEncoder_t commandEncoder();
void endKernelCoalescing();
void synchronize(SyncType syncType);
void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
void copy(MTLBuffer_t srcBuffer,
MTLBuffer_t dstBuffer,
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
uint64_t profileId,
SyncType syncType = SyncType::NONE);
void copy_and_sync(MTLBuffer_t srcBuffer,
MTLBuffer_t dstBuffer,
void copy_and_sync(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
@ -101,9 +94,11 @@ class TORCH_API MPSStream {
MTLCommandQueue_t stream() const {
return _commandQueue;
}
};
MTLDevice_t device() const;
MTLDevice_t device() const {
return [_commandQueue device];
}
/// Explicit conversion to Stream.
Stream unwrap() const {
@ -113,8 +108,8 @@ class TORCH_API MPSStream {
private:
Stream _stream;
MTLCommandQueue_t _commandQueue = nil;
MPSCommandBuffer_t _commandBuffer = nil;
MPSCommandBuffer_t _prevCommandBuffer = nil;
MPSCommandBuffer* _commandBuffer = nil;
MPSCommandBuffer* _prevCommandBuffer = nil;
MTLComputeCommandEncoder_t _commandEncoder = nil;
MPSGraphExecutionDescriptor* _executionDescriptor = nil;
MPSGraphCompilationDescriptor* _compilationDescriptor = nil;

View File

@ -51,10 +51,6 @@ MPSCommandBuffer* MPSStream::commandBuffer() {
return _commandBuffer;
}
id<MTLDevice> MPSStream::device() const {
return [_commandQueue device];
}
id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
if (!_commandEncoder) {
_commandEncoder = [commandBuffer() computeCommandEncoder].retain;

View File

@ -579,7 +579,7 @@ static void _rrelu_with_noise_train(
Tensor& noise,
const Scalar& lower_,
const Scalar& upper_,
const std::optional<Generator>& generator) {
std::optional<Generator> generator) {
using opmath_t = at::opmath_type<scalar_t>;
opmath_t lower = lower_.to<opmath_t>();
opmath_t upper = upper_.to<opmath_t>();

View File

@ -946,8 +946,6 @@ inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
return dnnl::memory::data_type::bf16;
} else if (dtype == ScalarType::Half) {
return dnnl::memory::data_type::f16;
} else if (dtype == ScalarType::Int) {
return dnnl::memory::data_type::s32;
} else if (dtype == ScalarType::Byte) {
return dnnl::memory::data_type::u8;
} else if (dtype == ScalarType::Char) {
@ -1093,7 +1091,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
M,
N,
K,
int64_t(1),
1,
ld_a,
ld_b,
ld_c,
@ -1133,12 +1131,6 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
} else if (dtype == ScalarType::BFloat16) {
static bool bf16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core;
return bf16_support;
} else if (dtype == ScalarType::Byte) {
static bool u8_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
return u8_support;
} else if (dtype == ScalarType::Char) {
static bool s8_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_vnni;
return s8_support;
}
return false;
}
@ -1189,9 +1181,6 @@ struct Pack : public KernelCache <PackKey, pack_t> {
} else if (dtype == ScalarType::BFloat16) {
static bool bf16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
return bf16_pack;
} else if (dtype == ScalarType::Byte || dtype == ScalarType::Char) {
static bool bit8_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
return bit8_pack;
}
return false;
}
@ -1293,54 +1282,6 @@ void brgemm(
beta, C, ld_c);
}
void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const bool add_C,
const unsigned char* A,
const unsigned char* B,
int32_t* C,
bool is_vnni) {
#if defined(ONEDNN_UKERNEL_ENABLED)
if (is_vnni && Brgemm::device_check(ScalarType::Byte)) {
Brgemm::call<unsigned char, unsigned char, int32_t>(
M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
return;
}
#endif
// raise an error if the path is not supported
TORCH_CHECK(false,
"U8 Brgemm is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
}
void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const bool add_C,
const unsigned char* A,
const signed char* B,
int32_t* C,
bool is_vnni) {
#if defined(ONEDNN_UKERNEL_ENABLED)
if (is_vnni && Brgemm::device_check(ScalarType::Char)) {
Brgemm::call<unsigned char, signed char, int32_t>(
M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
return;
}
#endif
// raise an error if the path is not supported
TORCH_CHECK(false,
"I8 Brgemm is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
}
void brgemm_release(bool is_vnni) {
#if defined(ONEDNN_UKERNEL_ENABLED)
if (is_vnni) {
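For orientation, the u8/u8-to-s32 and u8/s8-to-s32 brgemm entry points above compute an ordinary GEMM with 32-bit accumulation; the ISA checks (amx for u8, vnni for s8) and the VNNI-packed B layout only affect how the oneDNN ukernel executes it. A naive scalar reference, assuming row-major operands with explicit leading dimensions and a plain (non-VNNI) B:

#include <cstdint>

// Scalar reference for the integer brgemm paths above: A is M x K (leading
// dimension ld_a), B is K x N (leading dimension ld_b), C is M x N (leading
// dimension ld_c); add_C selects C = A*B versus C += A*B, and every product
// is widened to int32 before accumulating, as in the real kernel.
void brgemm_u8s8s32_reference(
    int64_t M, int64_t N, int64_t K,
    int64_t ld_a, int64_t ld_b, int64_t ld_c,
    bool add_C,
    const unsigned char* A, const signed char* B, int32_t* C) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      int32_t acc = add_C ? C[m * ld_c + n] : 0;
      for (int64_t k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(A[m * ld_a + k]) *
               static_cast<int32_t>(B[k * ld_b + n]);
      }
      C[m * ld_c + n] = acc;
    }
  }
}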

View File

@ -233,37 +233,11 @@ TORCH_API void brgemm(
float* C,
bool is_vnni = false);
TORCH_API void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const bool add_C,
const unsigned char* A,
const unsigned char* B,
int32_t* C,
bool is_vnni = true);
TORCH_API void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const bool add_C,
const unsigned char* A,
const signed char* B,
int32_t* C,
bool is_vnni = true);
// Release brgemm hardware context
TORCH_API void brgemm_release(bool is_vnni = true);
// Pack B matrix to get better performance if needed
TORCH_API void pack(
void pack(
int64_t K,
int64_t N,
int64_t ld_in,

View File

@ -620,7 +620,11 @@ Tensor _conj_physical(const Tensor& self) {
if (self.is_conj()) {
return self.conj().clone();
}
auto result = at::empty_like(self);
auto options = self.options();
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
options = options.dtype(c10::get_default_dtype());
}
auto result = at::empty_like(self, options);
return at::conj_physical_out(result, self);
}

View File

@ -1,21 +1,21 @@
#pragma once
#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
#include <c10/util/BFloat16.h> // For std::is_reduced_floating_point_v.
namespace at::native {
constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
constexpr double kGeluKappa = 0.044715;
template <typename T>
using reduced_fp_to_float_t = std::conditional_t<c10::is_reduced_floating_point_v<T>, float, T>;
using reduced_fp_to_float_t = std::conditional_t<std::is_reduced_floating_point_v<T>, float, T>;
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
float reduced_fp_to_float(T x) {
return float(x);
}
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
T reduced_fp_to_float(T x) {
return x;
}
@ -29,7 +29,7 @@ T scalar_gelu_approximated_with_tanh(T x) {
return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner));
}
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
const vec::Vectorized<T> kPointFiveVec(T(0.5));
const vec::Vectorized<T> kOneVec(T(1));
@ -40,7 +40,7 @@ vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x)
return kPointFiveVec * x * (kOneVec + inner_vec.tanh());
}
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
auto [x0, x1] = at::vec::convert_to_float<T>(x);
return at::vec::convert_from_float<T>(
@ -56,7 +56,7 @@ T scalar_gelu(T x) {
return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha));
}
template<typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
template<typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
const vec::Vectorized<T> kAlphaVec(T(M_SQRT1_2));
const vec::Vectorized<T> kOneVec(T(1));
@ -64,7 +64,7 @@ vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf());
}
template<typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
template<typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
auto [x0, x1] = at::vec::convert_to_float<T>(x);
return at::vec::convert_from_float<T>(vectorized_gelu(x0), vectorized_gelu(x1));
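Both dispatch paths above evaluate the two standard GeLU formulas, widening reduced-precision inputs to float first: the exact form 0.5 * x * (1 + erf(x / sqrt(2))) and the tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), where kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5 is exactly sqrt(2/pi). Scalar versions for reference:

#include <cmath>

// Scalar forms of the vectorized GeLU routines above. kBeta is written the
// same way as kGeluBeta in the header: M_SQRT2 * M_2_SQRTPI * 0.5 == sqrt(2/pi).
double gelu_exact(double x) {
  return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
}

double gelu_tanh_approximation(double x) {
  const double kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
  const double kKappa = 0.044715;
  return 0.5 * x * (1.0 + std::tanh(kBeta * (x + kKappa * x * x * x)));
}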

View File

@ -995,7 +995,7 @@ static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
dst_f32 += 1;
}
}
}
};
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
size_t m,

View File

@ -964,7 +964,7 @@ ScalingType get_scaling_type(
} // namespace
// Computes matrix multiply + bias while applying scaling to input and output matrices
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
// Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default.
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
// Known limitations:
// - Only works if mat1 is row-major and mat2 is column-major
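As a rough scalar picture of the per-tensor-scale case (a sketch under the usual Float8 convention that a stored value q represents q * scale, not the full _scaled_mm contract):

#include <cstdint>

// Dequantize-then-accumulate reference for scaled matmul with a row-major
// mat1 and a column-major mat2, per-tensor scales, and an optional bias added
// per output column. Float8 inputs are modeled as float; the Float8-output /
// scale_result case noted above is ignored.
void scaled_mm_reference(int64_t M, int64_t N, int64_t K,
                         const float* a_q, float scale_a,   // M x K, row-major
                         const float* b_q, float scale_b,   // K x N, column-major
                         const float* bias,                 // N entries, or nullptr
                         float* out) {                      // M x N, row-major
  for (int64_t i = 0; i < M; ++i) {
    for (int64_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += a_q[i * K + k] * b_q[j * K + k];  // column-major: b[k][j] lives at j * K + k
      }
      out[i * N + j] = scale_a * scale_b * acc + (bias ? bias[j] : 0.0f);
    }
  }
}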

View File

@ -19,7 +19,37 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#if defined(BUILD_ROWWISE_FP8_KERNEL)
#include <cute/tensor.hpp>
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
CUtensorMap* tensorMap,
CUtensorMapDataType tensorDataType,
cuuint32_t tensorRank,
void* globalAddress,
const cuuint64_t* globalDim,
const cuuint64_t* globalStrides,
const cuuint32_t* boxDim,
const cuuint32_t* elementStrides,
CUtensorMapInterleave interleave,
CUtensorMapSwizzle swizzle,
CUtensorMapL2promotion l2Promotion,
CUtensorMapFloatOOBfill oobFill) {
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
tensorMap,
tensorDataType,
tensorRank,
globalAddress,
globalDim,
globalStrides,
boxDim,
elementStrides,
interleave,
swizzle,
l2Promotion,
oobFill);
}
#include <cutlass/version.h>
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
@ -27,7 +57,16 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/version.h>
// Rename the global function symbol
#if CUTLASS_VERSION == 351
#include <cute/tensor.hpp>
#else
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
#endif
// Set everything back to normal
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
@ -38,8 +77,6 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
namespace {
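The include dance above is a generic preprocessor redirection: define a wrapper, #define the real symbol to the wrapper while the third-party header is parsed, then #undef it so the rest of the translation unit is untouched. A self-contained illustration with hypothetical names (the real code routes cuTensorMapEncodeTiled through the lazy NVRTC loader this way):

#include <cstdio>

static int lazy_loader_backend(int x) {        // stands in for nvrtc_cuTensorMapEncodeTiled
  std::printf("routed through the wrapper\n");
  return x + 1;
}

int third_party_call(int x) { return x; }      // the "real" symbol, normally resolved directly

// Redirect every use of third_party_call compiled while the macro is active
// (in the real code this bracket wraps the cute/ include):
#define third_party_call lazy_loader_backend
static int code_from_included_header(int x) {  // stands in for the header's call sites
  return third_party_call(x);                  // expands to lazy_loader_backend(x)
}
#undef third_party_call                        // back to normal for the rest of the file

int main() {
  std::printf("%d\n", code_from_included_header(1));  // prints the wrapper message, then 2
}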
@ -148,7 +185,11 @@ void f8f8bf16_rowwise_impl(
// Implement rowwise scaling epilogue.
constexpr int ColBroadcastStages = 0;
#if CUTLASS_VERSION == 351
constexpr int RowBroadcastStages = 0;
#else
constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;
#endif
using XScale = cutlass::epilogue::fusion::
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
@ -164,10 +205,19 @@ void f8f8bf16_rowwise_impl(
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
#if CUTLASS_VERSION == 351
#define FLIPPED_SCALES ;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
Multiply,
WScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
#else
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
XScale,
cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>;
#endif
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
Cast,
@ -230,6 +280,7 @@ void f8f8bf16_rowwise_impl(
StrideOutput stride_output = cutlass::make_cute_packed_stride(
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
#ifdef FLIPPED_SCALES
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
@ -245,6 +296,23 @@ void f8f8bf16_rowwise_impl(
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
#else
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K},
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
stride_a,
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
stride_b},
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
: nullptr},
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output,
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
stride_output}};
#endif
Gemm gemm;
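Whichever order the epilogue visitor tree applies the two scales in (the CUTLASS_VERSION branches above), the arithmetic is the same because per-row and per-column scalars commute: out[m][n] = x_scale[m] * w_scale[n] * (sum over k of XQ[m][k] * WQ[k][n]), with the optional bias added per column. A scalar sketch (bias placement after scaling is an assumption here):

#include <cstdint>

// Scalar reference for the rowwise-scaled FP8 GEMM built above: x_scale has
// one entry per row of XQ, w_scale one entry per column of WQ (WQ is stored
// N x K, i.e. column-major for the K x N operand), and FP8 inputs are modeled
// as float.
void f8f8bf16_rowwise_reference(int64_t M, int64_t N, int64_t K,
                                const float* XQ,       // M x K, row-major
                                const float* WQ,       // N x K, row-major
                                const float* x_scale,  // M entries
                                const float* w_scale,  // N entries
                                const float* bias,     // N entries, or nullptr
                                float* out) {          // M x N, row-major
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += XQ[m * K + k] * WQ[n * K + k];
      }
      out[m * N + n] = x_scale[m] * w_scale[n] * acc + (bias ? bias[n] : 0.0f);
    }
  }
}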

View File

@ -889,9 +889,9 @@ void lstm_miopen(Tensor& output, Tensor& hy, Tensor& cy,
int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
auto result = _miopen_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
miopenLSTM, num_layers, dropout_p, train, bidirectional, batch_first);
output = std::move(result.first);
hy = std::move(std::get<0>(result.second));
cy = std::move(std::get<1>(result.second));
output = result.first;
hy = std::get<0>(result.second);
cy = std::get<1>(result.second);
}
void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
@ -900,9 +900,9 @@ void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
auto result = _miopen_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]),
params, has_biases, miopenLSTM, num_layers, dropout_p, train, bidirectional);
output = std::move(result.first);
hy = std::move(std::get<0>(result.second));
cy = std::move(std::get<1>(result.second));
output = result.first;
hy = std::get<0>(result.second);
cy = std::get<1>(result.second);
}
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)

View File

@ -17,8 +17,7 @@
#include <ATen/ops/mm_native.h>
#endif
namespace at::native {
namespace xpu {
namespace at::native::xpu {
// result = beta * self + alpha * (mat1 * mat2)
Tensor& addmm_out(
@ -455,7 +454,7 @@ Tensor& tensordot_out(
TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("tensordot.out", TORCH_FN(tensordot_out));
}
} // namespace xpu
} // namespace at::native::xpu
TORCH_IMPL_FUNC(addmm_out_xpu)
(const Tensor& self,
@ -470,13 +469,11 @@ TORCH_IMPL_FUNC(addmm_out_xpu)
TORCH_IMPL_FUNC(mm_out_xpu)
(const Tensor& self, const Tensor& mat2, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(bmm_out_xpu)
(const Tensor& self, const Tensor& batch2, const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
}
@ -501,13 +498,7 @@ TORCH_IMPL_FUNC(baddbmm_out_xpu)
const Scalar& alpha,
const Tensor& result) {
xpu::baddbmm_out(
self,
batch1,
batch2,
beta,
alpha,
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
const_cast<Tensor&>(result));
self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
}
TORCH_IMPL_FUNC(addmv_out_xpu)
@ -517,8 +508,5 @@ TORCH_IMPL_FUNC(addmv_out_xpu)
const Scalar& beta,
const Scalar& alpha,
const Tensor& result) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
}
} // namespace at::native

View File

@ -100,6 +100,8 @@ at::Tensor quantized_convolution(
{c10::kXPU, c10::xpu::current_device()});
auto stream = GpuStreamManager::Instance().get_stream();
// create usr_md for tensors, and md for conv primitive
dnnl::memory::desc src_md, weight_md, output_md;
// input tensors config
dnnl::memory::dims src_dims = act.sizes().vec();
dnnl::memory::dims weight_dims = weight.sizes().vec();
@ -128,8 +130,7 @@ at::Tensor quantized_convolution(
bool src_need_zp = (act_scale != 0);
// create usr_md for tensors, and md for conv primitive
auto [src_md, weight_md, output_md] =
std::tie(src_md, weight_md, output_md) =
qconv_get_md(act, weight, output, groups);
// get tensor md

View File

@ -870,12 +870,7 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
auto device = MPSDevice::getInstance()->device();
library = [device newLibraryWithSource:str options:options error:&error];
if (library == nil) {
if ([error domain] == MTLLibraryErrorDomain && [error code] == MTLLibraryErrorCompileFailure) {
throw c10::SyntaxError([[error localizedDescription] UTF8String]);
}
TORCH_CHECK(false, "Failed to create metal library, error: ", [[error description] UTF8String]);
}
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
return library;
}

View File

@ -420,7 +420,7 @@ TORCH_API int register_conv_params<3>();
int register_linear_params() {
using SerializationType = std::tuple<at::Tensor, std::optional<at::Tensor>>;
[[maybe_unused]] static auto register_linear_params =
static auto register_linear_params =
torch::selective_class_<LinearPackedParamsBase>(
"quantized", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase"))
.def_pickle(
@ -495,7 +495,7 @@ int register_embedding_params() {
std::vector<double>,
std::vector<int64_t>>;
[[maybe_unused]] static auto register_embedding_params =
static auto register_embedding_params =
torch::selective_class_<EmbeddingPackedParamsBase>(
"quantized", TORCH_SELECTIVE_CLASS("EmbeddingPackedParamsBase"))
.def_pickle(

View File

@ -11,6 +11,27 @@
// sparsification, as a bitmask.
// NOTE: Algorithms might select LESS than 8 values in total in some cases.
namespace cutlass::platform {
template <>
struct numeric_limits<cutlass::bfloat16_t> {
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t infinity() {
return cutlass::bfloat16_t::bitcast(0x7f80);
}
};
#if CUTLASS_VERSION == 341
template <>
struct numeric_limits<cutlass::half_t> {
CUTLASS_HOST_DEVICE
static cutlass::half_t infinity() {
return cutlass::half_t::bitcast(0x7c00);
}
};
#endif
} // namespace cutlass::platform
namespace at::native {
template <typename Element, typename Pointwise>

View File

@ -916,6 +916,7 @@ _flash_attention_forward(
std::optional<Tensor> seqused_k = _seqused_k;
std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
std::optional<Tensor> alibi_slopes = _alibi_slopes;
const float softcap = 0.0;
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
@ -957,6 +958,7 @@ _flash_attention_forward(
is_causal,
non_null_window_left,
non_null_window_right,
softcap,
return_debug_mask,
std::nullopt /*gen_*/);
} else {
@ -980,6 +982,7 @@ _flash_attention_forward(
is_causal,
non_null_window_left,
non_null_window_right,
softcap,
return_debug_mask, /*return_softmax (this is used for testing)*/
std::nullopt);
}

View File

@ -94,6 +94,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// Currently unused args:
std::optional<at::Tensor> alibi_slopes{std::nullopt};
const float softcap = 0.0;
bool determinisitic{false};
auto& ctx = at::globalContext();
@ -132,6 +133,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
non_null_window_left,
non_null_window_right,
softcap,
determinisitic,
philox_seed,
philox_offset);
@ -154,6 +156,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
non_null_window_left,
non_null_window_right,
softcap,
determinisitic,
philox_seed,
philox_offset);

View File

@ -1,74 +0,0 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_causal>
struct Alibi {
const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;
__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
};
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
};
} // namespace pytorch_flash
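Element-wise, the removed helper adds the standard ALiBi bias to the attention scores: -slope * |row + max_seqlen_k - max_seqlen_q - col| in the general branch. The causal branch instead adds slope * col to every row, which differs from that formula only by a per-row constant and therefore gives the same result after the row-wise softmax. A scalar form:

#include <cmath>

// Bias applied to score(row, col) by the non-causal branch above, with the
// bottom-right alignment used by the kernel (row + seqlen_k - seqlen_q lines
// the last query up with the last key).
float alibi_bias(float slope, int row, int col, int max_seqlen_q, int max_seqlen_k) {
  return -slope * std::fabs(static_cast<float>(row + max_seqlen_k - max_seqlen_q - col));
}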

View File

@ -1,46 +0,0 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
namespace pytorch_flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool Varlen=true>
struct BlockInfo {
template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}
template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}
const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash
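The removed BlockInfo is pure bookkeeping for the varlen layout: cu_seqlens_q / cu_seqlens_k are exclusive prefix sums of the per-sequence lengths, so sequence bidb occupies packed rows [cu_seqlens[bidb], cu_seqlens[bidb + 1]) and its element offset is cu_seqlens[bidb] * row_stride, which is what q_offset / k_offset return when the cumulative pointers are present. A small sketch of that layout:

#include <cstdint>
#include <vector>

// Build the length-(b+1) cumulative offsets from per-sequence lengths, then
// compute the packed-row offset of one sequence, mirroring BlockInfo::q_offset
// in the varlen (cu_seqlens != nullptr) case.
std::vector<int32_t> make_cu_seqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t b = 0; b < seqlens.size(); ++b) {
    cu[b + 1] = cu[b] + seqlens[b];
  }
  return cu;
}

int64_t varlen_row_offset(const std::vector<int32_t>& cu_seqlens,
                          int bidb, int64_t row_stride) {
  return static_cast<int64_t>(cu_seqlens[bidb]) * row_stride;
}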

View File

@ -1,96 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
struct Dropout {
const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same_v<T, cutlass::half_t> || std::is_same_v<T, cutlass::bfloat16_t>)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}
};
} // namespace pytorch_flash
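Stripped of the Philox vectorization and the f16x2 mask trick, the decision made per element above is a single byte comparison: keep the value when its random byte is <= the 8-bit keep threshold, and rescale kept values by 1 / p_keep (the rp_dropout factor in the params). A scalar sketch; deriving the threshold as floor(p_keep * 255) is an assumption matching the uint8 comparison in the kernel:

#include <cmath>
#include <cstdint>

// Scalar analogue of apply_dropout above: rnd is one uniformly random byte
// from Philox, p_keep is the probability of keeping an activation.
float apply_dropout_scalar(float value, uint8_t rnd, float p_keep) {
  const uint8_t threshold = static_cast<uint8_t>(std::floor(p_keep * 255.0f));
  const bool keep = rnd <= threshold;
  return keep ? value / p_keep : 0.0f;  // kept values are rescaled by 1 / p_keep
}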

View File

@ -1,190 +0,0 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <cuda.h>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
void * __restrict__ oaccum_ptr;
// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;
// The pointer to the P matrix.
void * __restrict__ p_ptr;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
void * __restrict__ softmax_lseaccum_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;
int *__restrict__ blockmask;
// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
void * __restrict__ vnew_ptr;
// The stride between rows of the Q, K and V matrices.
index_t knew_batch_stride;
index_t vnew_batch_stride;
index_t knew_row_stride;
index_t vnew_row_stride;
index_t knew_head_stride;
index_t vnew_head_stride;
// The cos and sin matrices for rotary embedding.
void * __restrict__ rotary_cos_ptr;
void * __restrict__ rotary_sin_ptr;
// The indices to index into the KV cache.
int * __restrict__ cache_batch_idx;
// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// uint32_t p_dropout_in_uint;
// uint16_t p_dropout_in_uint16_t;
uint8_t p_dropout_in_uint8_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;
// Local window size
int window_size_left, window_size_right;
// Random state.
at::PhiloxCudaState philox_args;
int64_t * extragraph_offset;
int64_t * seed;
bool is_bf16;
bool is_causal;
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
bool is_seqlens_k_cumulative;
bool is_rotary_interleaved;
int num_splits; // For split-KV version
void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_bwd_params : public Flash_fwd_params {
// The dO and dQKV matrices.
void *__restrict__ do_ptr;
void *__restrict__ dq_ptr;
void *__restrict__ dk_ptr;
void *__restrict__ dv_ptr;
// To accumulate dQ
void *__restrict__ dq_accum_ptr;
void *__restrict__ dk_accum_ptr;
void *__restrict__ dv_accum_ptr;
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
// dv_accum_ptr;
// The stride between rows of the dO, dQ, dK and dV matrices.
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
index_t dq_batch_stride;
index_t dk_batch_stride;
index_t dv_batch_stride;
index_t dq_row_stride;
index_t dk_row_stride;
index_t dv_row_stride;
index_t dq_head_stride;
index_t dk_head_stride;
index_t dv_head_stride;
// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
bool deterministic;
index_t dq_accum_split_stride;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
} // namespace pytorch_flash
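The h / h_k / h_h_k_ratio fields encode the usual MQA/GQA grouping: h query heads share h_k key/value heads (h % h_k == 0), so a query head simply integer-divides by the precomputed ratio to find the key/value head it attends with:

// Map a query head index to its key/value head under MQA/GQA, given h query
// heads and h_k <= h key/value heads; h_h_k_ratio is the precomputed h / h_k.
inline int kv_head_for_query_head(int query_head, int h, int h_k) {
  const int h_h_k_ratio = h / h_k;
  return query_head / h_h_k_ratio;
}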

View File

@ -2,6 +2,7 @@
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include <c10/core/ScalarType.h>
#include <c10/core/DeviceType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint>
@ -32,13 +33,18 @@
#include <cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
// #include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <flash.h>
#include <namespace_config.h>
#include <static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
// #include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <c10/util/Exception.h>
namespace pytorch_flash {
namespace FLASH_NAMESPACE {
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
@ -70,7 +76,9 @@ void set_params_fprop(Flash_fwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool seqlenq_ngroups_swapped=false) {
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
// Reset the parameters
params = {};
@ -126,8 +134,19 @@ void set_params_fprop(Flash_fwd_params &params,
params.d_rounded = d_rounded;
// Set the different scale values.
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
#endif
if (softcap > 0.0) {
params.softcap = softmax_scale / softcap;
params.scale_softmax = softcap;
params.scale_softmax_log2 = softcap * M_LOG2E;
} else{
// Remove potential NaN
params.softcap = 0.0;
params.scale_softmax = softmax_scale;
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
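Numerically, the soft-cap branch above folds the cap into the two existing scale factors: the capped logit is softcap * tanh(score * softmax_scale / softcap), so the kernel pre-multiplies scores by softmax_scale / softcap (stored in params.softcap) and carries the outer softcap factor in scale_softmax / scale_softmax_log2; for large softcap this approaches the plain softmax_scale * score. A scalar reference implied by that packing (an interpretation of the parameters, not the kernel code):

#include <cmath>

// Soft-capped attention logit as implied by the parameter packing above;
// softcap <= 0 means "no cap", i.e. the ordinary scaled dot product.
float softcapped_logit(float qk, float softmax_scale, float softcap) {
  if (softcap <= 0.0f) {
    return softmax_scale * qk;
  }
  return softcap * std::tanh(qk * (softmax_scale / softcap));
}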
@ -162,6 +181,8 @@ void set_params_fprop(Flash_fwd_params &params,
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
params.unpadded_lse = unpadded_lse;
params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
}
void set_params_dgrad(Flash_bwd_params &params,
@ -195,7 +216,9 @@ void set_params_dgrad(Flash_bwd_params &params,
float softmax_scale,
int window_size_left,
int window_size_right,
bool deterministic) {
const float softcap,
bool deterministic,
const bool unpadded_lse) {
set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@ -208,7 +231,10 @@ void set_params_dgrad(Flash_bwd_params &params,
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
softcap,
false, // seqlenq_ngroups_swapped
unpadded_lse);
// Set the pointers and strides.
params.do_ptr = dout.data_ptr();
@ -244,11 +270,13 @@ void set_params_dgrad(Flash_bwd_params &params,
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
}
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
} else {
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
}
});
});
});
}
@ -357,6 +385,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
@ -396,6 +425,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }
@ -441,7 +472,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
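round_multiple(x, m) is the usual round-up-to-a-multiple, (x + m - 1) / m * m in integer arithmetic, and the changed line additionally buckets any head size that rounds to 224 or more up to 256. For example:

#include <cstdio>

// round_multiple and the head-size bucketing used above.
constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }
constexpr int head_size_rounded(int head_size) {
  return round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
}

int main() {
  std::printf("%d %d %d\n",
              head_size_rounded(64),    // 64
              head_size_rounded(96),    // 96
              head_size_rounded(200));  // rounds to 224, bucketed to 256
}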
@ -476,11 +507,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
softcap
);
// Keep references to these tensors to extend their lifetime
auto [softmax_lse_accum, out_accum] = set_params_splitkv(params, batch_size, num_heads,
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
@ -497,26 +531,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
} else {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
params.seed = seed_t.data_ptr<int64_t>();
params.extragraph_offset = offset_t.data_ptr<int64_t>();
}
seed_t = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
params.rng_state = reinterpret_cast<uint64_t*>(seed_t.data_ptr());
params.philox_args = philox_state;
} else {
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
}
}else{
seed_t = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
}
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
@ -556,6 +576,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
@ -604,6 +625,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const int head_size_og = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
@ -667,7 +690,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
if (seqlenq_ngroups_swapped) {
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
@ -679,7 +701,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
@ -689,7 +711,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
auto opts = q.options();
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat));
at::Tensor p;
// Only return softmax if there's dropout to reduce compilation time
if (return_softmax) {
@ -720,7 +742,10 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
softmax_scale,
window_size_left,
window_size_right,
seqlenq_ngroups_swapped);
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
params.total_q = total_q;
if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
@ -740,9 +765,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
@ -750,26 +775,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
} else {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
params.seed = seed_t.data_ptr<int64_t>();
params.extragraph_offset = offset_t.data_ptr<int64_t>();
}
seed_t = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
params.rng_state = reinterpret_cast<uint64_t*>(seed_t.data_ptr());
params.philox_args = philox_state;
} else {
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
seed_t = at::empty({}, at::dtype(at::kLong));
offset_t = at::empty({}, at::dtype(at::kLong));
}
seed_t = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
}
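In the block above the per-call RNG bookkeeping changes: instead of separate long seed/offset scalars, seed_t is now a two-element uint64 tensor whose storage the kernel fills through params.rng_state, with offset_t kept as an empty placeholder. A host-side stand-in for that packing, assuming slot 0 holds the philox seed and slot 1 the offset (an assumption, since the writing side is not shown in this hunk):

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t rng_state[2] = {0, 0};         // plays the role of seed_t's {2} storage
  const uint64_t seed = 0x1234abcdULL;    // hypothetical philox seed
  const uint64_t offset = 4ULL * 8 * 32;  // hypothetical counter offset (batch * heads * 32)
  rng_state[0] = seed;
  rng_state[1] = offset;
  std::printf("seed=%llu offset=%llu\n",
              (unsigned long long)rng_state[0], (unsigned long long)rng_state[1]);
  return 0;
}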
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
@ -788,7 +799,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
}
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
@ -797,7 +808,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
HEADDIM_SWITCH(params.d, [&] {
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
});
});
});
}
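run_mha_bwd now also dispatches on is_causal at compile time through the BOOL_SWITCH macro. A minimal stand-in for that pattern; the macro below is illustrative, not the actual definition from static_switch.h:

#include <cstdio>

#define MY_BOOL_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()

template <bool Is_causal>
void run_kernel_stub() { std::printf("Is_causal=%d\n", int(Is_causal)); }

int main() {
  bool is_causal = true;  // runtime flag becomes a template parameter below
  MY_BOOL_SWITCH(is_causal, Is_causal, [&] { run_kernel_stub<Is_causal>(); });
  return 0;
}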
@ -818,6 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
@ -877,7 +891,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
@ -976,21 +990,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
softmax_scale,
window_size_left,
window_size_right,
deterministic);
softcap,
deterministic,
/*unpadded_lse*/false);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd;
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
params.rng_state = philox_seed.data_ptr<uint64_t>();
}
params.philox_args = philox_args;
@ -1019,7 +1029,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &out, // total_q x num_heads x head_size
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@ -1034,6 +1044,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset)
@ -1099,7 +1110,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
@ -1154,7 +1165,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
at::Tensor dq_accum;
if (loop) {
// We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
@ -1165,6 +1176,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
// Same holds for softmax_d, since LSE is stored in unpadded format.
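Worked numbers for the padding described in the comment above, with arbitrary sequence lengths; the extra 128 rows per sequence keep neighbouring sequences far enough apart that the over-reaching atomicAdds stay inside the allocation:

#include <cstdio>

int main() {
  const int batch_size = 3;
  const int seqlens_q[3] = {5, 130, 70};
  int total_q = 0;
  for (int s : seqlens_q) total_q += s;                // 205
  const int padded_rows = total_q + 128 * batch_size;  // 205 + 384 = 589
  std::printf("total_q=%d padded_rows=%d\n", total_q, padded_rows);
  return 0;
}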
if (!deterministic) {
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
@ -1210,21 +1222,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_scale,
window_size_left,
window_size_right,
deterministic);
softcap,
deterministic,
/*unpadded_lse*/true);
params.total_q = total_q;
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
auto launch = &run_mha_bwd;
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
params.rng_state = philox_seed.data_ptr<uint64_t>();
}
params.philox_args = philox_args;
@ -1265,6 +1273,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {
@ -1375,7 +1384,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
@ -1403,7 +1412,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
/*p_dropout=*/0.f,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
softcap
);
at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
@ -1486,7 +1497,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
}
// Keep references to these tensors to extend their lifetime
auto [softmax_lse_accum, out_accum] = set_params_splitkv(params, batch_size, num_heads,
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);

View File

@ -1,10 +1,11 @@
#pragma once
#include <cstddef>
#include <namespace_config.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
namespace pytorch_flash {
namespace FLASH_NAMESPACE {
TORCH_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -18,6 +19,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
@ -39,6 +41,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
@ -59,6 +62,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset);
@ -84,8 +88,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset);
} // namespace pytorch_flash
} // namespace FLASH_NAMESPACE

View File

@ -1,827 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/PhiloxUtils.cuh>
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
#include <ATen/native/transformers/cuda/flash_attn/dropout.h>
#include <ATen/native/transformers/cuda/flash_attn/alibi.h>
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_N,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
// This gives the correct layout, idk why.
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
// Stride<Stride<_1, _64>, _8> >{},
// auto t = make_tile(Layout<Shape<_8, _2, _2>,
// Stride<_1, _64, _8> >{},
auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
make_layout(Int<TileShape_K>{}));
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int MMA_N,
class... Args,
class TiledMMA>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
auto t = make_tile(make_layout(Int<TileShape_M>{}),
Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
if (Is_local) {
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
}
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
+ (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
+ (m_block_max - 1) * kBlockM;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
+ (m_block_max - 1) * kBlockM;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.v_row_stride, _1{}));
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutQdO{});
Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
// Double buffer for sQ
Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
typename Kernel_traits::SmemLayoutPdS{});
Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
// sP and sdQ share the same memory so be careful
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
using GmemTiledCopydO = std::conditional_t<
Is_first,
typename Kernel_traits::GmemTiledCopydO,
typename Kernel_traits::GmemTiledCopyQKV
>;
GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
using GmemLayoutAtomdQaccum = std::conditional_t<
!Seq_parallel,
typename Kernel_traits::GmemTiledCopydQaccum,
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
>;
GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
// if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
// __syncthreads();
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
// printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
// }
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K)
Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K)
Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K)
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N)
Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N)
Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N)
Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N)
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N)
Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N)
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
//
// Copy Atom retiling
//
auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
// auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
// if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
// if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
// Partition sP and sdS to match the accumulator partitioning
// This has to be tiled_mma_sdp, not tiled_mma_dkv
// auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
// if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
// if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) {
// printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data());
// }
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
//
// PREDICATES
//
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
// Allocate predicate tensors for k
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
// Set predicates for k bounds
if (!Is_even_K) {
#pragma unroll
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
#pragma unroll
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
}
// Prologue
// We'll advance gdQ and gdQaccum before the 1st read/write.
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
int m_block = m_block_max - 1;
int m_block_min = (!Is_causal && !Is_local)
? 0
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
// If not local, we're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
// However, if local, then it is possible to have some blocks of K & V not attending to any query.
// We might need to exit early and write 0 to dK and dV for those blocks.
// Otherwise we get wrong result for the case where we don't enter the for loop.
// And we might read OOB elements from gQ and gdO.
// This also covers the case where actual_seqlen_q == 0
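A worked instance of the bound reasoned about in the comment above, with arbitrary sizes (causal case, so window_size_right is 0); this is a standalone sketch, not code from the kernel:

#include <algorithm>
#include <cstdio>

int main() {
  const int kBlockM = 128, kBlockN = 128;
  const int seqlen_q = 1000, seqlen_k = 1000, window_size_right = 0;
  const int n_block = 5;
  const int m_block_max = (seqlen_q + kBlockM - 1) / kBlockM;  // ceil(1000/128) = 8
  const int m_block_min = std::max(
      0, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);  // 5
  std::printf("m_block_min=%d m_block_max=%d\n", m_block_min, m_block_max);
  return 0;
}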
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
clear(tdKrdK);
clear(tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
}
if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
tQsQ.data() = tQsQ.data() + size(sQ);
tSsQ.data() = tSsQ.data() + size(sQ);
tdKsQt.data() = tdKsQt.data() + size(sQ);
}
if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
if (Kernel_traits::Is_V_in_regs) {
// Clear the smem tiles to account for predicated off loads
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::cp_async_fence();
}
Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
if (!Is_first) {
// Clear the smem tiles to account for predicated off loads
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
} else {
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N)
static_assert(decltype(size<0>(taccScS))::value == 4);
// Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.
Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccScS_row(mi));
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
}
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0.
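A tiny numeric check of the comment above (values arbitrary): with lse set to +inf for an out-of-bounds row, exp(score - lse) is exactly 0, so the row contributes nothing even if ALiBi has shifted the raw score.

#include <cmath>
#include <cstdio>

int main() {
  const float lse_oob = INFINITY;
  const float score = 4.2f;  // some finite, possibly ALiBi-shifted, score
  std::printf("prob=%g\n", std::exp(score - lse_oob));  // prints 0
  return 0;
}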
// Tensor tKrK = make_fragment_like(tKsK);
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
// cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
// // if (cute::thread(1, 0)) { print(tKrK); }
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
if (!Kernel_traits::Is_V_in_regs) {
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
pytorch_flash::cp_async_fence();
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
if (Is_first) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
}
if (Kernel_traits::Is_V_in_regs) {
cute::cp_async_wait<1>();
__syncthreads();
Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
}
const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
bidb, bidh, tidx, params.h);
clear(acc_dv);
clear(acc_dk);
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
pytorch_flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
for (; m_block >= m_block_min; --m_block) {
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
clear(acc_s);
cute::cp_async_wait<0>();
__syncthreads();
Tensor dP_sum = make_fragment_like(lse);
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
// if (cute::thread0()) { print(sK); }
// Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
// #pragma unroll
// for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
// }
// if (cute::thread0()) { print(tSrK); }
pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
}
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
// actual_seqlen_k, because acc_s would be some finite value for those indices.
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
// so the result would still be correct.
// However, it's possible that the values in acc_s are so large that they overflow
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
// So we need to mask out the elements beyond actual_seqlen_k.
if (!Is_causal && !Is_local) {
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
}
} else if (Is_causal) {
// Putting this causal masking right after acc_s is *much* slower for some reason.
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
} else if (Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
pytorch_flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q, AtomLayoutMS * 16,
params.window_size_left, params.window_size_right);
}
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
pytorch_flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if constexpr (Is_dropout) {
int warp_id = tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
static_assert(MMA_N_SdP % 2 == 0);
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
acc_s, block_row_idx, block_col_idx, AtomLayoutMS
);
}
// Convert scores from fp32 to fp16/bf16
Tensor rP = !Is_dropout
? pytorch_flash::convert_type<Element>(acc_s)
: pytorch_flash::convert_type_relu<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
// if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaSdP>(rP.layout()));
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
// if (cute::thread0()) { print(tPaP); }
// __syncthreads();
// if (cute::thread0()) { print(sP); }
Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA
CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA
clear(acc_dp);
// Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout()));
// #pragma unroll
// for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {
// #pragma unroll
// for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) {
// acc_dp_reshaped(mi, ni) = -dP_sum(mi);
// }
// }
// if (cute::thread0()) { print(dP_sum); }
pytorch_flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
);
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
auto pointwise_mult = [](float p, float dp, float d) {
return p * (!Is_dropout || p >= 0 ? dp - d : d);
};
#pragma unroll
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
}
}
// if (cute::thread0()) { print(dS); }
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
if (Is_first || Seq_parallel) {
clear(acc_dq);
} else {
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
make_layout(get<0>(acc_dq.layout()),
get<2>(acc_dq.layout()),
get<1>(acc_dq.layout())));
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
}
if (Double_buffer && m_block > m_block_min) {
// Double buffer for sQ
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
tQsQ.data() = tQsQ.data() + sQ_offset;
tSsQ.data() = tSsQ.data() + sQ_offset;
// Advance gQ
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
pytorch_flash::cp_async_fence();
}
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
// Convert dS from fp32 to fp16
Tensor tdSrdS = pytorch_flash::convert_type<Element>(dS_reshaped);
// if (cute::thread0()) { print(tPrP); }
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
__syncthreads();
// Layout p_l = tPrP.layout();
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
// pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
// if (cute::thread0()) { print(acc_dv); }
__syncthreads(); // Need syncthreads since we're writing to the same sdO location
if (m_block > m_block_min) {
// Advance gdO
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
if (Is_first) {
tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
} else {
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
pytorch_flash::cp_async_fence();
}
}
pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
// if (cute::thread0()) { print(acc_dq); }
if (m_block > m_block_min) {
gLSE.data() = gLSE.data() + (-int(kBlockM));
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
}
if (!Is_last) {
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
make_layout(get<0>(acc_dq.layout()),
get<2>(acc_dq.layout()),
get<1>(acc_dq.layout())));
if (!Seq_parallel) {
cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
} else {
// if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
}
} else {
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from fp32 to fp16
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
}
pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
// if (cute::thread0()) { print(acc_dk); }
if (Double_buffer) { // Double buffer for sQ
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
}
if (!Double_buffer && m_block > m_block_min) {
__syncthreads();
// Advance gQ
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
pytorch_flash::cp_async_fence();
}
if (Is_first && m_block > m_block_min) {
cute::copy(tdOrdO, tdOsdO);
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
}
if (Is_last) {
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
#pragma unroll
for (int m = 0; m < size<1>(tdQgdQ); ++m) {
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
}
}
}
}
// Epilogue
if (Is_dropout) {
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
}
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dk and acc_dv from fp32 to fp16/bf16
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
// Partition sdV and sdK to match the accumulator partitioning
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
// We need syncthreads here since we're writing to the same location as sK and sV.
// Without syncthreads, some thread might modify the location of sK while another thread
// is reading it for dQ gemm, leading to a race condition.
// If Is_last, there's already a __syncthreads() at the end of the loop.
if (!Is_last) { __syncthreads(); }
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.x;
// const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.y;
// const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
if (n_block_max == 1) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
} else {
// Iterating backward from n_block_max - 1 to 0 might save 1 register
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
}
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash

View File

@ -1,338 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <c10/cuda/CUDAException.h>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>
namespace pytorch_flash {
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#endif
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
template<bool Clear_dQaccum=true, typename Kernel_traits>
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
pytorch_flash::clear_dKVaccum<Kernel_traits>(params);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
pytorch_flash::convert_dQ<Kernel_traits>(params, nsplits);
}
template<typename Kernel_traits>
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
pytorch_flash::convert_dKV<Kernel_traits>(params);
}
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream) {
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid_m(num_m_block, params.b, params.h);
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
int gridDimx = num_n_block;
if (params.deterministic) {
auto dprops = at::cuda::getCurrentDeviceProperties();
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
}
dim3 grid_n(gridDimx, params.b, params.h);
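Worked numbers for the deterministic grid sizing above (hypothetical device and problem sizes): with 108 SMs and b * h = 32, gridDimx = ceil(108 / 32) = 4, so at most 4 column blocks per (batch, head) run concurrently, each able to use its own dq_accum slice via the dq_accum_split_stride set earlier. A standalone sketch of the arithmetic only:

#include <cstdio>

int main() {
  const int multiProcessorCount = 108;  // hypothetical SM count
  const int b = 4, h = 8;
  const int gridDimx = (multiProcessorCount + b * h - 1) / (b * h);  // ceil(108/32) = 4
  std::printf("gridDimx=%d\n", gridDimx);
  return 0;
}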
if (!params.deterministic) {
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
} else {
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
// a multiple of kBlockN, we'll need to apply mask in the loop.
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
}
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
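// A worked sizing example for the launch above (all values hypothetical, not taken from this
// file): seqlen_q = seqlen_k = 1024, kBlockM = kBlockN = 128, b = 2, h = 16, 108 SMs.
//   num_m_block = ceil(1024 / 128) = 8  -> grid_m = (8, 2, 16)
//   num_n_block = ceil(1024 / 128) = 8  -> grid_n = (8, 2, 16) in the default path
// In the deterministic path gridDimx = ceil(108 / (2 * 16)) = 4, so grid_n = (4, 2, 16) and
// flash_bwd_convert_dq_kernel is launched with nsplits = 4: each grid-x slot accumulates into
// its own dQaccum slice and convert_dQ sums those slices in a fixed order, which is what makes
// the dQ accumulation reproducible.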
template<typename Kernel_traits, bool Is_dropout>
void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
#endif
}
template<typename T>
void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
}
} else { // 96 KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
}
});
}
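// Checking the "104 KB" figure used in the branch above (Headdim = 32):
//   2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128) = 2 * (20480 + 32768) = 106496 B = 104 KB,
// so the first branch only triggers on parts that expose at least 104 KB of opt-in shared
// memory per block; everything else falls back to the configuration annotated as 96 KB.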
template<typename T>
void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// This is slightly faster: splitting M more means we need fewer registers to store the LSE.
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// This has a lot of register spilling
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else {
// if (params.h == params.h_k) {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
// } else {
// }
}
});
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
}
template<typename T>
void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
if constexpr(!Is_dropout) { // 92KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
} else { // 116 KB
// This is faster for dropout since we don't have many registers to spare
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
}
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
}
});
}
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first); it is unclear why.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
} else {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
});
}
template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
}
});
}
template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
}
});
}
template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
});
}
template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
if constexpr (!Is_dropout) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
}
}
});
}
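// The three tiers above line up with the opt-in shared-memory-per-block limits of the targeted
// parts (per the CUDA occupancy tables): sm90 allows up to 227 KB, sm80 up to 163 KB, and
// sm86/sm89 up to 99 KB. Hence the 176 KB test selects only H100, the 144 KB test selects A100,
// and the remaining parts fall through to the V-in-registers, no-double-buffering variant,
// which (per the comment above) is only compiled for the no-dropout case.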
}; // namespace pytorch_flash

View File

@ -1,377 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
namespace pytorch_flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
// The last coordinate is the "page".
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
make_layout(get<0>(do_.layout()),
get<2>(do_.layout()))));
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
Tensor do_fp32 = pytorch_flash::convert_type<float>(do_reshaped);
Tensor o_fp32 = pytorch_flash::convert_type<float>(o_reshaped);
#pragma unroll
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
#pragma unroll
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
}
pytorch_flash::SumOp<float> sum_op;
dP_sum_cur = pytorch_flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
if (threadIdx.x % THREADS_PER_ROW == 0) {
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
}
}
}
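// Shape sketch for the reshape above, assuming (hypothetically) kBlockM = 128 and kHeadDim = 64:
// do_ and o arrive as (8, 128 / 32, 64 / 64) = (8, 4, 1) fragments and are viewed as
// (4, 8 * 1) = (4, 8), i.e. one row of 8 values per mi. Each row is reduced with
// Allreduce<THREADS_PER_ROW>, and only the first lane of each row group writes its dP_sum entry.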
////////////////////////////////////////////////////////////////////////////////////////////////////
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
inline __device__ void compute_dot_do_o(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.do_row_stride, _1{}));
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.o_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
Shape<Int<kBlockM>>{}, Stride<_1>{});
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
// TODO: careful, we're zeroing out dQaccum with type float4, but when
// we do atomicAdds, we use type float. The layouts are different. Check this.
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
// Allocate predicate tensors for k
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
// Set predicates for k bounds
#pragma unroll
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
Tensor tdOrdO = make_fragment_like(tdOgdO);
Tensor tdOrO = make_fragment_like(tdOgO);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
);
// Strictly speaking we need to scale dP up by 1/p_dropout, but instead we don't and only scale
// the final results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by
// p_dropout here, so that (dP - dP_sum) stays on the same scale.
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
if (Clear_dQaccum) {
// We're actually not zeroing out all of dQaccum, but only the part that we're going to
// do atomicAdds on.
Tensor zero = make_fragment_like(tdQgdQaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void clear_dKVaccum(const Params &params) {
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
Tensor zero = make_fragment_like(tdKgdKaccum);
clear(zero);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dQ from dQaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_k.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int m_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.dq_row_stride, _1{}));
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.h * params.d_rounded, _1{}));
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdQ{});
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
clear(acc_dq);
for (int s = 0; s < nsplits; ++s) {
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
}
#pragma unroll
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
// Convert acc_dq from fp32 to fp16
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
__syncthreads();
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
#pragma unroll
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
// Shared memory.
extern __shared__ char smem_[];
const int n_block = blockIdx.x;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.z;
// The thread index.
const int tidx = threadIdx.x;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
const BlockInfo binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
+ n_block * kBlockN) * params.d_rounded;
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dk_row_stride, _1{}));
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.dv_row_stride, _1{}));
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
Stride<Int<kHeadDim>, _1>{});
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
typename Kernel_traits::SmemLayoutdKV{});
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
}
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) {
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
}
// Convert acc_dk from fp32 to fp16
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
} // namespace pytorch_flash

View File

@ -1,378 +0,0 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContextLight.h>
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h>
namespace pytorch_flash {
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#endif
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // Enforce constraints
pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
#if defined(ARCH_SUPPORTS_FLASH)
pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
static_assert(Log_max_splits >= 1);
pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
constexpr size_t smem_size = Kernel_traits::kSmemSize;
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.b, params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
const bool return_softmax = params.p_ptr != nullptr;
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// Softmax is only returned when dropout is enabled, to reduce compilation time.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
}
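// Note on the 48 * 1024 checks in this file: a kernel may use at most 48 KB of dynamic shared
// memory by default; anything larger has to be requested explicitly via
// cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
// which is why the attribute is only set when the smem requirement crosses that threshold.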
template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
constexpr size_t smem_size = Kernel_traits::kSmemSize;
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If Is_local, set Is_causal to false
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
});
});
});
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
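// A worked example for the combine dispatch above (hypothetical values): with Headdim = 128 the
// divisibility test gives kBlockM = 4, and for params.num_splits = 10 the chain lands in the
// "num_splits <= 16" branch, i.e. Log_max_splits = 4 (2^4 = 16 >= 10). grid_combine then has
// ceil(b * h * seqlen_q / 4) blocks, each reducing up to 16 partial results.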
template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}
template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 64;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
// Using block size (64 x 256) is 27% slower for seqlen=2k
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 96;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), and 128 x 64 is the fastest for non-causal.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// These two are always slower
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// The 1st ones are good for H100 and A100.
// The 2nd one is good for A6000 because we get slightly better occupancy.
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
}
});
});
}
template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, H100, 128 x 32 is the fastest.
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
// and 128 x 64 with 8 warps is the fastest for non-causal.
if (is_sm8x) {
if constexpr(!Is_causal) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if constexpr(!Is_dropout) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
});
});
}
template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
// If we have N = 32, there are only 1024 elements to load at once, where each load
// is 8 elements. This means we can only use 128 threads and not 256 threads.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
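// Spelling out the "128 threads, not 256" arithmetic from the comment above: with Headdim = 224
// the smem tile width kBlockKSmem is 32, so a kBlockN = 32 slab is 32 * 32 = 1024 elements per
// pass; at 8 elements (16 bytes of fp16/bf16) per vectorized load that is 1024 / 8 = 128 loading
// threads, which rules out the 8-warp (256-thread) 128 x 32 configuration.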
template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_sm, max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
// For A100, we want to run with 128 x 64 (128KB smem).
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
}
// 64 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
// 96 KB
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
});
});
}
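// Plugging Headdim = 256 into the condition above:
//   2 * 256 * (128 + 2 * 64) = 131072 B = 128 KB  (per-block requirement)
//   4 * 256 * ( 64 + 2 * 64) = 196608 B = 192 KB  (per-SM cut-off)
// On an A100 (163 KB opt-in per block, 164 KB per SM) both tests pass and the 128 x 64 tile is
// used; on an H100 (228 KB per SM) the second test fails, so the 64 x 64 tile is chosen and two
// ~96 KB CTAs fit on one SM, matching the comment above.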
}; // namespace pytorch_flash

View File

@ -1,344 +0,0 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#pragma once
#include <cute/algorithm/copy.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_types.h>
namespace pytorch_flash{
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = int64_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
Tile<Int<16 * kNWarps>, _16, _16>>;
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem; using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
using SmemLayoutVtransposed = decltype(
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages" covering columns 0-63 and 64-127
// respectively. If we used 16 threads per row for the gmem read, then when writing to smem
// threads 0-7 would write to the first page and threads 8-15 to the second page, hitting
// the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomOaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopyOaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
using GmemTiledCopyRotcossin = decltype(
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
using GmemTiledCopyRotcossinCont = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
GmemLayoutAtomRotcossin{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
static constexpr bool No_double_buffer = No_double_buffer_;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
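// The swizzle width matches the smem page: Swizzle<2> for 32-element pages and Swizzle<3>
// for 64-element pages, spreading 128-bit accesses across shared-memory banks to avoid
// bank conflicts.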
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
static_assert(kNWarps % AtomLayoutMSdP == 0);
static_assert(kNWarps % AtomLayoutNdKV == 0);
static_assert(kNWarps % AtomLayoutMdQ == 0);
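// The AtomLayout* parameters split the kNWarps warps between the M and N dimensions of each
// GEMM. E.g. with 8 warps and the defaults above, S = Q K^T and dP use a 1 x 8 warp
// arrangement (a 16 x 128 tile per MMA step), while dK/dV and dQ use 2 x 4 (32 x 64 tiles).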
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
using SmemLayoutAtomQdO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQdO = decltype(tile_to_shape(
SmemLayoutAtomQdO{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemLayoutAtomKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
// SmemLayoutAtomQdO{},
SmemLayoutAtomKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemLayoutKtransposed = decltype(
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
// Temporarily disabling this for hdim 256 on sm86 and sm89
// static_assert(kBlockN >= 64);
static_assert(kBlockN >= 32);
// TD [2023-03-19]: I don't know why kPBlockN = 16 and kSwizzlePdS = 3 is the fastest.
static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
using SmemLayoutAtomPdS = decltype(
composition(Swizzle<kSwizzlePdS, 3, 3>{},
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
Stride<Int<kPBlockN>, _1>>{}));
using SmemLayoutPdS = decltype(tile_to_shape(
SmemLayoutAtomPdS{},
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
using SmemLayoutPdStransposed = decltype(
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutQdOtransposed = decltype(
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
using SmemLayoutAtomdKV = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutdKV = decltype(tile_to_shape(
SmemLayoutAtomdKV{},
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
using SmemLayoutAtomdQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutdQ = decltype(tile_to_shape(
SmemLayoutAtomdQ{},
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
// Double buffer for sQ
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
static constexpr int kSmemSize1colblock = kSmemQdOSize
+ (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + kSmemPSize
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
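// Worked example (illustrative, assuming kBlockM = 64, kBlockN = 128, kHeadDim = 64, fp16,
// double-buffered sQ, V kept in smem): kSmemQdOSize = 3 * 64 * 64 * 2 B = 24 KiB,
// kSmemKVSize = 32 KiB, kSmemdSSize = kSmemPSize = 16 KiB, kSmemdQSize = 8 KiB, giving
// kSmemSize = 24 + 32 + 16 + max(16, 8) = 88 KiB.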
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// We use kBlockKSmem instead of kHeadDim here to avoid bank conflicts, though it doesn't
// seem to affect speed in practice.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since the same threadblock
// never re-reads the same address, so there is no reuse for L1 to exploit. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopydO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydKV = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemTiledCopydQ = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
using GmemLayoutAtomdQaccum = std::conditional_t<
kBlockKSmem == 32,
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
Stride< _8, _1>>,
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
Stride< _16, _1>>
>;
using GmemTiledCopydQaccum = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
GmemLayoutAtomdQaccum{},
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
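// The atomic-add variant below stores one fp32 value per thread, since atomicAdd operates
// on a single 32-bit word; the 8 x 32 thread shape keeps each row's adds coalesced.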
using GmemTiledCopydQaccumAtomicAdd = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
Stride<_32, _1>>{},
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash
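As an illustrative usage sketch (not taken from this diff), the backward traits resolve to concrete numbers once instantiated. The tile shape below is an assumed configuration chosen so the arithmetic is easy to check, not a value read from the launch templates; it relies only on members declared in this header.
// Illustrative only: head dim 64, 64 x 128 (M x N) tile, 8 warps, with the defaults
// declared above (fp16 element type, V kept in smem, double-buffered sQ).
using BwdTraits64 = pytorch_flash::Flash_bwd_kernel_traits<64, 64, 128, 8>;
static_assert(BwdTraits64::kNThreads == 256, "8 warps -> 256 threads");
static_assert(BwdTraits64::kBlockKSmem == 64, "head dim 64 -> 64-element smem pages");
static_assert(BwdTraits64::kPBlockN == 64, "P/dS handled in 64-column blocks");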

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

View File

@ -1,14 +0,0 @@
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{
template<>
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
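The deleted translation units above each pin one (dtype, head dim) specialization of run_mha_bwd_, so, as their comments note, the head dimensions compile in parallel. A minimal sketch of how a caller might select among them follows; it is hypothetical, assumes the run_mha_bwd_ template declaration pulled in by the include above, and covers only the head dims visible in this excerpt (the remaining ones live in sibling files).
// Hypothetical dispatcher, not taken from this diff.
namespace pytorch_flash {
template <typename Dtype>
void run_mha_bwd_for_headdim(int headdim_rounded, Flash_bwd_params &params, cudaStream_t stream) {
    switch (headdim_rounded) {
        case 32:  run_mha_bwd_<Dtype, 32>(params, stream);  break;
        case 128: run_mha_bwd_<Dtype, 128>(params, stream); break;
        case 160: run_mha_bwd_<Dtype, 160>(params, stream); break;
        case 192: run_mha_bwd_<Dtype, 192>(params, stream); break;
        case 224: run_mha_bwd_<Dtype, 224>(params, stream); break;
        case 256: run_mha_bwd_<Dtype, 256>(params, stream); break;
        default:  break; // remaining head dims are defined in sibling translation units
    }
}
} // namespace pytorch_flash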

Some files were not shown because too many files have changed in this diff.