Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-24 07:27:32 +08:00

Compare commits — 1 commit: "Update-Fla…" by sanchitint

| Author | SHA1 | Date |
|---|---|---|
| sanchitint | a5e8b0ad38 | |
@@ -3,9 +3,6 @@ set -eux -o pipefail

GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}

# cuda arm build for Grace Hopper solely
export TORCH_CUDA_ARCH_LIST="9.0"

SCRIPTPATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
source $SCRIPTPATH/aarch64_ci_setup.sh

5  .ci/docker/aotriton_version.txt  Normal file

@@ -0,0 +1,5 @@
0.8b
manylinux_2_28
rocm6.2
6f8cbcac8a92775291bb1ba8f514d4beb350baf4
e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8
@@ -268,7 +268,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.2.4
ROCM_VERSION=6.1
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes

@@ -279,7 +279,7 @@ case "$image" in
PROTOBUF=yes
DB=yes
VISION=yes
ROCM_VERSION=6.3
ROCM_VERSION=6.2.4
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes

@@ -497,7 +497,7 @@ docker build \
--build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \
--build-arg "KATEX=${KATEX:-}" \
--build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a;gfx942}" \
--build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH:-gfx90a}" \
--build-arg "IMAGE_NAME=${IMAGE_NAME}" \
--build-arg "UCX_COMMIT=${UCX_COMMIT}" \
--build-arg "UCC_COMMIT=${UCC_COMMIT}" \
@@ -113,6 +113,13 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt

# Install AOTriton (Early fail)
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton

# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH
@@ -1,7 +1,7 @@
set -euo pipefail

readonly version=v24.04
readonly src_host=https://github.com/ARM-software
readonly src_host=https://review.mlplatform.org/ml
readonly src_repo=ComputeLibrary

# Clone ACL
23  .ci/docker/common/install_aotriton.sh  Executable file

@@ -0,0 +1,23 @@
#!/bin/bash

set -ex

source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"

TARBALL='aotriton.tar.gz'
# This read command always returns with exit code 1
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < aotriton_version.txt || true
ARCH=$(uname -m)
AOTRITON_INSTALL_PREFIX="$1"
AOTRITON_URL="https://github.com/ROCm/aotriton/releases/download/${VER}/aotriton-${VER}-${MANYLINUX}_${ARCH}-${ROCMBASE}-shared.tar.gz"

cd "${AOTRITON_INSTALL_PREFIX}"
# Must use -L to follow redirects
curl -L --retry 3 -o "${TARBALL}" "${AOTRITON_URL}"
ACTUAL_SHA256=$(sha256sum "${TARBALL}" | cut -d " " -f 1)
if [ "${SHA256}" != "${ACTUAL_SHA256}" ]; then
echo -n "Error: The SHA256 of downloaded tarball is ${ACTUAL_SHA256},"
echo " which does not match the expected value ${SHA256}."
exit
fi
tar xf "${TARBALL}" && rm -rf "${TARBALL}"
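A minimal sketch (not part of the diff) of why the `read -d "\n" ... || true` line in the new script works: in double quotes, `"\n"` is a literal backslash followed by `n`, so the delimiter is effectively a backslash that never appears in the file. `read` therefore consumes the whole file, splits it on whitespace/newlines into the five variables, and returns non-zero at EOF, which is why `|| true` is needed under `set -e`. The file path below is a hypothetical stand-in for `aotriton_version.txt`.

```bash
#!/bin/bash
# Sketch only: demonstrate the `read -d "\n" ... || true` idiom used above.
set -euo pipefail

tmpfile=$(mktemp)   # hypothetical stand-in for aotriton_version.txt
printf '%s\n' 0.8b manylinux_2_28 rocm6.2 \
  6f8cbcac8a92775291bb1ba8f514d4beb350baf4 \
  e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8 > "$tmpfile"

# read runs to EOF (exit code 1), splitting the words across the named variables.
read -d "\n" VER MANYLINUX ROCMBASE PINNED_COMMIT SHA256 < "$tmpfile" || true
echo "VER=${VER} MANYLINUX=${MANYLINUX} ROCMBASE=${ROCMBASE}"
echo "PINNED_COMMIT=${PINNED_COMMIT} SHA256=${SHA256}"
```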
@@ -62,22 +62,6 @@ install_ubuntu() {
sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
done

# ROCm 6.3 had a regression where initializing static code objects had significant overhead
if [[ $(ver $ROCM_VERSION) -eq $(ver 6.3) ]]; then
# clr build needs CppHeaderParser but can only find it using conda's python
/opt/conda/bin/python -m pip install CppHeaderParser
git clone https://github.com/ROCm/HIP -b rocm-6.3.x
HIP_COMMON_DIR=$(readlink -f HIP)
git clone https://github.com/jeffdaily/clr -b release/rocm-rel-6.3-statco-hotfix
mkdir -p clr/build
pushd clr/build
cmake .. -DCLR_BUILD_HIP=ON -DHIP_COMMON_DIR=$HIP_COMMON_DIR
make -j
cp hipamd/lib/libamdhip64.so.6.3.* /opt/rocm/lib/libamdhip64.so.6.3.*
popd
rm -rf HIP clr
fi

# Cleanup
apt-get autoclean && apt-get clean
rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
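The `ver` comparison in the hunk above relies on a helper defined elsewhere in install_rocm.sh. The snippet below is a hedged sketch of a common implementation (an assumption, not copied from this diff) that packs a dotted version into one zero-padded integer so shell arithmetic comparisons like `-eq` or `-ge` behave as expected.

```bash
#!/bin/bash
# Assumed helper (not taken from this diff): pack up to three dot-separated
# components into a single sortable integer for arithmetic comparison.
ver() {
    printf "%3d%03d%03d" $(echo "$1" | tr '.' ' ')
}

ROCM_VERSION=6.3
if [[ $(ver "$ROCM_VERSION") -eq $(ver 6.3) ]]; then
    echo "ROCm ${ROCM_VERSION} matches 6.3, so the static-code-object hotfix would apply"
fi
```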
@@ -92,6 +92,13 @@ RUN apt-get update -y && \
RUN bash ./install_rocm_drm.sh && rm install_rocm_drm.sh
RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh

# Install AOTriton
COPY ./common/common_utils.sh common_utils.sh
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm && rm install_aotriton.sh aotriton_version.txt
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton

FROM ${BASE_TARGET} as final
COPY --from=openssl /opt/openssl /opt/openssl
# Install patchelf

@@ -198,3 +198,10 @@ ADD ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh && rm install_rocm_magma.sh
ADD ./common/install_miopen.sh install_miopen.sh
RUN bash ./install_miopen.sh ${ROCM_VERSION} && rm install_miopen.sh

# Install AOTriton
COPY ./common/common_utils.sh common_utils.sh
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN bash ./install_aotriton.sh /opt/rocm && rm install_aotriton.sh aotriton_version.txt
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton

@@ -304,7 +304,7 @@ pytest-cpp==2.3.0
#Pinned versions: 2.3.0
#test that import:

z3-solver==4.12.6.0
z3-solver==4.12.2.0
#Description: The Z3 Theorem Prover Project
#Pinned versions:
#test that import:
@@ -107,6 +107,13 @@ COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt

# Install AOTriton
COPY ./aotriton_version.txt aotriton_version.txt
COPY ./common/common_utils.sh common_utils.sh
COPY ./common/install_aotriton.sh install_aotriton.sh
RUN ["/bin/bash", "-c", "./install_aotriton.sh /opt/rocm && rm -rf install_aotriton.sh aotriton_version.txt common_utils.sh"]
ENV AOTRITON_INSTALLED_PREFIX /opt/rocm/aotriton

# This is needed by sccache
COPY ./common/install_openssl.sh install_openssl.sh
ENV OPENSSL_ROOT_DIR /opt/openssl
@@ -53,10 +53,22 @@ cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.')
TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6"
case ${CUDA_VERSION} in
12.6)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX"
fi
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.4)
if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then
TORCH_CUDA_ARCH_LIST="9.0"
else
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
fi
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
12.1)
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0"
EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON")
;;
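To make the branch added for CUDA 12.6/12.4 concrete, here is a small self-contained sketch (the function name and sample inputs are invented for illustration, not part of the build script) of how TORCH_CUDA_ARCH_LIST ends up differing between a Grace Hopper (cuda-aarch64) wheel and an x86_64 wheel.

```bash
#!/bin/bash
# Illustration only: mirrors the case-statement logic above for two sample builds.
resolve_arch_list() {
  local cuda_version=$1 gpu_arch_type=$2
  local arch_list="5.0;6.0;7.0;7.5;8.0;8.6"
  case ${cuda_version} in
    12.6|12.4)
      if [[ "$gpu_arch_type" = "cuda-aarch64" ]]; then
        arch_list="9.0"                       # aarch64 wheels target sm_90 only
      elif [[ "$cuda_version" = "12.6" ]]; then
        arch_list="${arch_list};9.0+PTX"
      else
        arch_list="${arch_list};9.0"
      fi
      ;;
    12.1)
      arch_list="${arch_list};9.0"
      ;;
  esac
  echo "$arch_list"
}

resolve_arch_list 12.6 cuda-aarch64   # -> 9.0
resolve_arch_list 12.6 cuda           # -> 5.0;6.0;7.0;7.5;8.0;8.6;9.0+PTX
```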
@@ -186,6 +186,15 @@ do
OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array
done

# FIXME: Temporary until https://github.com/pytorch/pytorch/pull/137443 lands
# Install AOTriton
if [ -e ${PYTORCH_ROOT}/.ci/docker/aotriton_version.txt ]; then
cp -a ${PYTORCH_ROOT}/.ci/docker/aotriton_version.txt aotriton_version.txt
bash ${PYTORCH_ROOT}/.ci/docker/common/install_aotriton.sh ${ROCM_HOME} && rm aotriton_version.txt
export AOTRITON_INSTALLED_PREFIX=${ROCM_HOME}/aotriton
ROCM_SO_FILES+=("libaotriton_v2.so")
fi

# rocBLAS library files
ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library
ROCBLAS_LIB_DST=lib/rocblas/library

@@ -257,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC))
DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/})
DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/})

# PyTorch 2.6+ (AOTriton 0.8b+)
# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly
if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n')
if [[ -z ${LIBAOTRITON_DIR} ]]; then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n')
fi
AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n'))
AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images"
AKS_DST="lib/aotriton.images"
DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/})
DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/})
fi

echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"

SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
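Two idioms in the AKS block above are worth unpacking. The sketch below uses sample values only (file names and the source directory are assumptions, not part of the packaging script): the awk pipeline prints 1/0 for a numeric ">=" check, which `(( ... ))` treats as a boolean, and `"${array[@]/#/prefix}"` prepends a prefix to every array element when building the DEPS_AUX source list.

```bash
#!/bin/bash
# Sample values for illustration; the real script derives these from the ROCm install.
PYTORCH_VERSION=2.6.0

# awk compares the leading numeric prefix, so "2.6.0" is read as 2.6 here.
if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then
  echo "PyTorch ${PYTORCH_VERSION} is >= 2.6, so AKS kernel images would be packaged"
fi

AKS_FILES=(attn_fwd.aks2 attn_bwd.aks2)        # hypothetical kernel-storage files
AKS_SRC="/opt/rocm/lib/aotriton.images"        # assumed source directory
# "${arr[@]/#/prefix}" prepends the prefix to every element:
printf '%s\n' "${AKS_FILES[@]/#/${AKS_SRC}/}"
```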
@@ -228,7 +228,7 @@ if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then
export CMAKE_BUILD_TYPE=RelWithAssert
fi

# Do not change workspace permissions for ROCm and s390x CI jobs
# Do not change workspace permissions for ROCm CI jobs
# as it can leave workspace with bad permissions for cancelled jobs
if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then
# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)

@@ -12,9 +12,9 @@ export TERM=vt100
# shellcheck source=./common.sh
source "$(dirname "${BASH_SOURCE[0]}")/common.sh"

# Do not change workspace permissions for ROCm and s390x CI jobs
# Do not change workspace permissions for ROCm CI jobs
# as it can leave workspace with bad permissions for cancelled jobs
if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then
if [[ "$BUILD_ENVIRONMENT" != *rocm* && -d /var/lib/jenkins/workspace ]]; then
# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)
WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace")
cleanup_workspace() {

@@ -86,13 +86,6 @@ if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then
export VALGRIND=OFF
fi

if [[ "$BUILD_ENVIRONMENT" == *s390x* ]]; then
# There are additional warnings on s390x, maybe due to newer gcc.
# Skip this check for now
export VALGRIND=OFF
fi

if [[ "${PYTORCH_TEST_RERUN_DISABLED_TESTS}" == "1" ]] || [[ "${CONTINUE_THROUGH_ERROR}" == "1" ]]; then
# When rerunning disable tests, do not generate core dumps as it could consume
# the runner disk space when crashed tests are run multiple times. Running out
@@ -541,7 +534,7 @@ test_perf_for_dashboard() {
--dynamic-batch-only "$@" \
--output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_${mode}_${device}_${target}.csv"
fi
if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]]; then
if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]] && [[ "$mode" == "inference" ]]; then
TORCHINDUCTOR_CPP_WRAPPER=1 $TASKSET python "benchmarks/dynamo/$suite.py" \
"${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \
--output "$TEST_REPORTS_DIR/${backend}_cpp_wrapper_${suite}_${dtype}_${mode}_${device}_${target}.csv"

@@ -917,20 +910,10 @@ test_libtorch_api() {
else
# Exclude IMethodTest that relies on torch::deploy, which will instead be ran in test_deploy
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="${MNIST_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_api -k "not IMethodTest"

# On s390x, pytorch is built without llvm.
# Even if it would be built with llvm, llvm currently doesn't support used features on s390x and
# test fails with errors like:
# JIT session error: Unsupported target machine architecture in ELF object pytorch-jitted-objectbuffer
# unknown file: Failure
# C++ exception with description "valOrErr INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/tensorexpr/llvm_jit.h":34, please report a bug to PyTorch. Unexpected failure in LLVM JIT: Failed to materialize symbols: { (main, { func }) }
if [[ "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi
python test/run_test.py --cpp --verbose -i cpp/test_tensorexpr
fi

# quantization is not fully supported on s390x yet
if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* && "${BUILD_ENVIRONMENT}" != *s390x* ]]; then
if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *asan* ]]; then
# NB: This test is not under TORCH_BIN_DIR but under BUILD_BIN_DIR
export CPP_TESTS_DIR="${BUILD_BIN_DIR}"
python test/run_test.py --cpp --verbose -i cpp/static_runtime_test
@@ -1,9 +1,8 @@
---
# NOTE there must be no spaces before the '-', so put the comma last.
# The check bugprone-unchecked-optional-access is also turned on.
# Note that it can cause clang-tidy to hang randomly. The tracking issue
# The check bugprone-unchecked-optional-access is also turned off atm
# because it causes clang-tidy to hang randomly. The tracking issue
# can be found at https://github.com/llvm/llvm-project/issues/69369.
# When that happens, we can disable it on the problematic code by NOLINT.
InheritParentConfig: true
Checks: '
bugprone-*,

@@ -13,6 +12,7 @@ bugprone-*,
-bugprone-lambda-function-name,
-bugprone-reserved-identifier,
-bugprone-swapped-arguments,
-bugprone-unchecked-optional-access,
clang-analyzer-core.*,
clang-analyzer-cplusplus.*,
clang-analyzer-nullability.*,
2  .github/ISSUE_TEMPLATE/bug-report.yml  vendored

@@ -5,7 +5,7 @@ body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+sort%3Acreated-desc+). Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.
#### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pytorch/pytorch/issues?q=is%3Aissue+sort%3Acreated-desc+).
- type: textarea
attributes:
label: 🐛 Describe the bug
4  .github/ISSUE_TEMPLATE/documentation.yml  vendored

@@ -2,10 +2,6 @@ name: 📚 Documentation
description: Report an issue related to https://pytorch.org/docs/stable/index.html

body:
- type: markdown
attributes:
value: >
#### Note: Please report your documentation issue in English to ensure it can be understood and addressed by the development team.
- type: textarea
attributes:
label: 📚 The doc issue
4  .github/ISSUE_TEMPLATE/feature-request.yml  vendored

@@ -2,10 +2,6 @@ name: 🚀 Feature request
description: Submit a proposal/request for a new PyTorch feature

body:
- type: markdown
attributes:
value: >
#### Note: Please write your feature request in English to ensure it can be understood and addressed by the development team.
- type: textarea
attributes:
label: 🚀 The feature, motivation and pitch
4  .github/ISSUE_TEMPLATE/pt2-bug-report.yml  vendored

@@ -3,10 +3,6 @@ description: Create a report to help us reproduce and fix the bug
labels: ["oncall: pt2"]

body:
- type: markdown
attributes:
value: >
#### Note: Please write your bug report in English to ensure it can be understood and addressed by the development team.
- type: markdown
attributes:
value: >
4  .github/actions/diskspace-cleanup/action.yml  vendored

@@ -17,10 +17,6 @@ runs:
set -ex
diskspace_cutoff=${{ inputs.diskspace-cutoff }}
docker_root_dir=$(docker info -f '{{.DockerRootDir}}')
if [ ! -d "$docker_root_dir" ]; then
echo "Docker root directory ($docker_root_dir) does not exist. Skipping disk space check."
exit 0
fi
diskspace=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //')
msg="Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified"
if [[ "$diskspace" -ge "$diskspace_cutoff" ]] ; then
18  .github/actions/setup-rocm/action.yml  vendored

@@ -5,6 +5,20 @@ description: Set up ROCm host for CI
runs:
using: composite
steps:
- name: Set DOCKER_HOST
shell: bash
run: echo "DOCKER_HOST=unix:///run/user/$(id -u)/docker.sock" >> "${GITHUB_ENV}"

- name: Remove leftover Docker config file
shell: bash
continue-on-error: true
run: |
set -ex

cat ~/.docker/config.json || true
# https://stackoverflow.com/questions/64455468/error-when-logging-into-ecr-with-docker-login-error-saving-credentials-not
rm -f ~/.docker/config.json

- name: Stop all running docker containers
if: always()
shell: bash

@@ -97,10 +111,8 @@ runs:
shell: bash
run: |
# All GPUs are visible to the runner; visibility, if needed, will be set by run_test.py.
# Add render group for container creation.
render_gid=`cat /etc/group | grep render | cut -d: -f3`
# The --group-add daemon and --group-add bin are needed in the Ubuntu 24.04 and Almalinux OSs respectively.
# This is due to the device files (/dev/kfd & /dev/dri) being owned by video group on bare metal.
# This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries.
# The group name corresponding to group ID 1 can change depending on the OS, so both are necessary.
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device /dev/dri --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}"
echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd --device=/dev/dri --group-add video --group-add daemon --group-add bin" >> "${GITHUB_ENV}"
@@ -1,12 +1,12 @@
# Self-Hosted IBM Z Github Actions Runner.

# Temporary image: amd64 dependencies.
FROM --platform=linux/amd64 docker.io/ubuntu:24.04 as ld-prefix
FROM docker.io/amd64/ubuntu:23.10 as ld-prefix
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get -y install ca-certificates libicu74 libssl3
RUN apt-get update && apt-get -y install ca-certificates libicu72 libssl3

# Main image.
FROM --platform=linux/s390x docker.io/ubuntu:24.04
FROM docker.io/s390x/ubuntu:23.10

# Packages for pytorch building and testing.
ENV DEBIAN_FRONTEND=noninteractive
15  .github/workflows/_linux-build.yml  vendored

@@ -219,10 +219,6 @@ jobs:
if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
JENKINS_USER=
USED_IMAGE="${DOCKER_IMAGE_S390X}"
# ensure that docker container cleanly exits in 12 hours
# if for some reason cleanup action doesn't stop container
# when job is cancelled
DOCKER_SHELL_CMD="sleep 12h"

# since some steps are skipped on s390x, if they are necessary, run them here
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"

@@ -230,7 +226,6 @@ jobs:
else
JENKINS_USER="--user jenkins"
USED_IMAGE="${DOCKER_IMAGE}"
DOCKER_SHELL_CMD=
fi

# Leaving 1GB for the runner and other things

@@ -240,7 +235,7 @@ jobs:
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))

# detached container should get cleaned up by teardown_ec2_linux
# Used for JENKINS_USER and DOCKER_SHELL_CMD, which can be empty
# Used for JENKINS_USER, which can be empty
# shellcheck disable=SC2086
container_name=$(docker run \
-e BUILD_ENVIRONMENT \

@@ -271,8 +266,7 @@ jobs:
${JENKINS_USER} \
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
-w /var/lib/jenkins/workspace \
"${USED_IMAGE}" \
${DOCKER_SHELL_CMD}
"${USED_IMAGE}"
)
docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh'

@@ -338,5 +332,6 @@ jobs:
shell: bash
run: |
# on s390x stop the container for clean worker stop
docker stop -a || true
docker kill -a || true
# ignore expansion of "docker ps -q" since it could be empty
# shellcheck disable=SC2046
docker stop $(docker ps -q) || true
41  .github/workflows/_linux-test.yml  vendored

@@ -81,7 +81,7 @@ jobs:
steps:
- name: Setup SSH (Click me for login details)
uses: pytorch/test-infra/.github/actions/setup-ssh@main
if: ${{ !contains(matrix.runner, 'gcp.a100') && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if: ${{ !contains(matrix.runner, 'gcp.a100') }}
with:
github-secret: ${{ secrets.GITHUB_TOKEN }}
instructions: |

@@ -95,10 +95,9 @@ jobs:
- name: Setup Linux
uses: ./.github/actions/setup-linux
if: inputs.build-environment != 'linux-s390x-binary-manywheel'

- name: configure aws credentials
if : ${{ inputs.aws-role-to-assume != '' && inputs.build-environment != 'linux-s390x-binary-manywheel' }}
if : ${{ inputs.aws-role-to-assume != '' }}
uses: aws-actions/configure-aws-credentials@v3
with:
role-to-assume: ${{ inputs.aws-role-to-assume }}

@@ -108,13 +107,11 @@ jobs:
- name: Calculate docker image
id: calculate-docker-image
uses: pytorch/test-infra/.github/actions/calculate-docker-image@main
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
with:
docker-image-name: ${{ inputs.docker-image }}

- name: Use following to pull public copy of the image
id: print-ghcr-mirror
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
env:
ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }}
shell: bash

@@ -124,7 +121,6 @@ jobs:
- name: Pull docker image
uses: pytorch/test-infra/.github/actions/pull-docker-image@main
if: inputs.build-environment != 'linux-s390x-binary-manywheel'
with:
docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }}

@@ -170,7 +166,6 @@ jobs:
with:
name: ${{ inputs.build-environment }}
s3-bucket: ${{ inputs.s3-bucket }}
use-gha: ${{ inputs.use-gha }}

- name: Download TD artifacts
continue-on-error: true
@@ -267,21 +262,9 @@ jobs:
# comes from https://github.com/pytorch/test-infra/pull/6058
TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))

if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
SHM_OPTS=
JENKINS_USER=

# since some steps are skipped on s390x, if they are necessary, run them here
env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}"
env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}"
else
SHM_OPTS="--shm-size=${SHM_SIZE}"
JENKINS_USER="--user jenkins"
fi

# detached container should get cleaned up by teardown_ec2_linux
# TODO: Stop building test binaries as part of the build phase
# Used for GPU_FLAG, SHM_OPTS and JENKINS_USER since that doesn't play nice
# Used for GPU_FLAG since that doesn't play nice
# shellcheck disable=SC2086,SC2090
container_name=$(docker run \
${GPU_FLAG:-} \

@@ -333,11 +316,11 @@ jobs:
--security-opt seccomp=unconfined \
--cap-add=SYS_PTRACE \
--ipc=host \
${SHM_OPTS} \
--shm-size="${SHM_SIZE}" \
--tty \
--detach \
--name="${container_name}" \
${JENKINS_USER} \
--user jenkins \
-v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \
-w /var/lib/jenkins/workspace \
"${DOCKER_IMAGE}"

@@ -345,11 +328,6 @@ jobs:
# Propagate download.pytorch.org IP to container
grep download.pytorch.org /etc/hosts | docker exec -i "${container_name}" sudo bash -c "/bin/cat >> /etc/hosts"
echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}"

if [[ ${BUILD_ENVIRONMENT} == *"s390x"* ]]; then
docker exec -t "${container_name}" sh -c "python3 -m pip install -r .ci/docker/requirements-ci.txt"
fi

docker exec -t "${container_name}" sh -c "python3 -m pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}"

- name: Upload pytest cache if tests failed

@@ -489,12 +467,3 @@ jobs:
echo "NVIDIA driver detects $GPU_COUNT GPUs. The runner has a broken GPU, shutting it down..."
.github/scripts/stop_runner_service.sh
fi

- name: Cleanup docker
if: always() && inputs.build-environment == 'linux-s390x-binary-manywheel'
shell: bash
run: |
# on s390x stop the container for clean worker stop
# ignore expansion of "docker ps -q" since it could be empty
# shellcheck disable=SC2046
docker stop $(docker ps -q) || true
1  .github/workflows/_mac-test-mps.yml  vendored

@@ -152,7 +152,6 @@ jobs:
set -e

${CONDA_RUN} python3 test/run_test.py --mps --verbose
MTL_CAPTURE_ENABLED=1 ${CONDA_RUN} python3 test/test_mps.py --verbose -k test_metal_capture

- name: Print remaining test logs
shell: bash
16  .github/workflows/inductor-perf-test-nightly.yml  vendored

@@ -28,11 +28,6 @@ on:
required: false
type: boolean
default: false
cppwrapper:
description: Run inductor_cpp_wrapper?
required: false
type: boolean
default: false
cudagraphs:
description: Run inductor_cudagraphs?
required: false

@@ -43,6 +38,11 @@ on:
required: false
type: boolean
default: false
freeze_autotune_cudagraphs:
description: Run inductor_cudagraphs with freezing and max autotune for inference?
required: false
type: boolean
default: false
aotinductor:
description: Run aot_inductor for inference?
required: false

@@ -111,7 +111,7 @@ jobs:
if: github.event.schedule == '0 7 * * 1-6'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha

@@ -127,7 +127,7 @@ jobs:
if: github.event.schedule == '0 7 * * 0'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-cudagraphs_low_precision-true
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha

@@ -143,7 +143,7 @@ jobs:
if: github.event_name == 'workflow_dispatch'
with:
build-environment: linux-focal-cuda12.4-py3.10-gcc9-sm80
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-false-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-inductor-build.outputs.test-matrix }}
use-gha: anything-non-empty-to-use-gha
18  .github/workflows/inductor-rocm.yml  vendored

@@ -29,13 +29,13 @@ jobs:
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}

linux-focal-rocm6_3-py3_10-inductor-build:
name: rocm6.3-py3.10-inductor
linux-focal-rocm6_2-py3_10-inductor-build:
name: rocm6.2-py3.10-inductor
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [

@@ -44,15 +44,15 @@ jobs:
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-inductor-test:
linux-focal-rocm6_2-py3_10-inductor-test:
permissions:
id-token: write
contents: read
name: rocm6.3-py3.10-inductor
name: rocm6.2-py3.10-inductor
uses: ./.github/workflows/_rocm-test.yml
needs: linux-focal-rocm6_3-py3_10-inductor-build
needs: linux-focal-rocm6_2-py3_10-inductor-build
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-inductor-build.outputs.test-matrix }}
secrets: inherit
18  .github/workflows/periodic.yml  vendored

@@ -139,13 +139,13 @@ jobs:
test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.test-matrix }}
secrets: inherit

linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [

@@ -155,19 +155,19 @@ jobs:
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit

linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build:
6  .github/workflows/pull.yml  vendored

@@ -411,15 +411,15 @@ jobs:
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-build:
linux-focal-rocm6_2-py3_10-build:
# don't run build twice on main
if: github.event_name == 'pull_request'
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
30  .github/workflows/rocm.yml  vendored

@@ -26,36 +26,36 @@ jobs:
id-token: write
contents: read

linux-focal-rocm6_3-py3_10-build:
linux-focal-rocm6_2-py3_10-build:
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 3, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 4, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 5, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 6, num_shards: 6, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.2" },
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit
77  .github/workflows/s390x-periodic.yml  vendored

@@ -1,77 +0,0 @@
name: s390x-periodic

on:
schedule:
# We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs.
# Also run less frequently on weekends.
- cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests
push:
tags:
- ciflow/periodic/*
- ciflow/s390/*
branches:
- release/*
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
cancel-in-progress: true

permissions: read-all

jobs:
llm-td:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/llm_td_retrieval.yml
permissions:
id-token: write
contents: read

target-determination:
name: before-test
uses: ./.github/workflows/target_determination.yml
needs: llm-td
permissions:
id-token: write
contents: read

linux-manylinux-2_28-py3-cpu-s390x-build:
if: github.repository_owner == 'pytorch'
name: linux-manylinux-2_28-py3-cpu-s390x
uses: ./.github/workflows/_linux-build.yml
with:
build-environment: linux-s390x-binary-manywheel
docker-image-name: pytorch/manylinuxs390x-builder:cpu-s390x-main
runner: linux.s390x
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 2, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 3, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 4, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 5, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 6, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 7, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 8, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 9, num_shards: 10, runner: "linux.s390x" },
{ config: "default", shard: 10, num_shards: 10, runner: "linux.s390x" },
]}
secrets: inherit

linux-manylinux-2_28-py3-cpu-s390x-test:
permissions:
id-token: write
contents: read
name: linux-manylinux-2_28-py3-cpu-s390x
uses: ./.github/workflows/_linux-test.yml
needs:
- linux-manylinux-2_28-py3-cpu-s390x-build
- target-determination
with:
build-environment: linux-s390x-binary-manywheel
docker-image: pytorch/manylinuxs390x-builder:cpu-s390x-main
test-matrix: ${{ needs.linux-manylinux-2_28-py3-cpu-s390x-build.outputs.test-matrix }}
timeout-minutes: 480
use-gha: "yes"
secrets: inherit
18  .github/workflows/slow.yml  vendored

@@ -103,13 +103,13 @@ jobs:
test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }}
secrets: inherit

linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
test-matrix: |
{ include: [

@@ -118,19 +118,19 @@ jobs:
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
secrets: inherit

linux-jammy-py3_10-clang15-asan-build:
24  .github/workflows/trunk.yml  vendored

@@ -164,36 +164,36 @@ jobs:
runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral"
secrets: inherit

linux-focal-rocm6_3-py3_10-build:
name: linux-focal-rocm6.3-py3.10
linux-focal-rocm6_2-py3_10-build:
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-focal-rocm6.3-py3.10
build-environment: linux-focal-rocm6.2-py3.10
docker-image-name: pytorch-linux-focal-rocm-n-py3
sync-tag: rocm-build
test-matrix: |
{ include: [
{ config: "default", shard: 1, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "default", shard: 2, num_shards: 2, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.2' || 'linux.rocm.gpu.2' }}" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "${{ github.event_name == 'push' && github.event.ref == 'refs/heads/main' && 'linux.rocm.gpu.mi300.4' || 'linux.rocm.gpu.4' }}" },
{ config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2" },
{ config: "distributed", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.4" },
]}
secrets: inherit

linux-focal-rocm6_3-py3_10-test:
linux-focal-rocm6_2-py3_10-test:
permissions:
id-token: write
contents: read
name: linux-focal-rocm6.3-py3.10
name: linux-focal-rocm6.2-py3.10
uses: ./.github/workflows/_rocm-test.yml
needs:
- linux-focal-rocm6_3-py3_10-build
- linux-focal-rocm6_2-py3_10-build
- target-determination
with:
build-environment: linux-focal-rocm6.3-py3.10
docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }}
build-environment: linux-focal-rocm6.2-py3.10
docker-image: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-focal-rocm6_2-py3_10-build.outputs.test-matrix }}
tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl"
secrets: inherit
3  .gitmodules  vendored

@@ -134,3 +134,6 @@
[submodule "third_party/kleidiai"]
path = third_party/kleidiai
url = https://git.gitlab.arm.com/kleidi/kleidiai.git
[submodule "third_party/flash-attention"]
path = third_party/flash-attention
url = https://github.com/drisspg/flash-attention.git
43  BUILD.bazel

@@ -38,29 +38,26 @@ aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/s

generated_cpu_cpp = [
"aten/src/ATen/RegisterBackendSelect.cpp",
"aten/src/ATen/RegisterCPU_0.cpp",
"aten/src/ATen/RegisterCPU_1.cpp",
"aten/src/ATen/RegisterCPU_2.cpp",
"aten/src/ATen/RegisterCPU_3.cpp",
"aten/src/ATen/RegisterCPU.cpp",
"aten/src/ATen/RegisterFunctionalization_0.cpp",
"aten/src/ATen/RegisterFunctionalization_1.cpp",
"aten/src/ATen/RegisterFunctionalization_2.cpp",
"aten/src/ATen/RegisterFunctionalization_3.cpp",
# "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
"aten/src/ATen/RegisterMkldnnCPU_0.cpp",
"aten/src/ATen/RegisterNestedTensorCPU_0.cpp",
"aten/src/ATen/RegisterQuantizedCPU_0.cpp",
"aten/src/ATen/RegisterSparseCPU_0.cpp",
"aten/src/ATen/RegisterSparseCsrCPU_0.cpp",
"aten/src/ATen/RegisterZeroTensor_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
"aten/src/ATen/RegisterMeta_0.cpp",
"aten/src/ATen/RegisterSparseMeta_0.cpp",
"aten/src/ATen/RegisterQuantizedMeta_0.cpp",
"aten/src/ATen/RegisterNestedTensorMeta_0.cpp",
"aten/src/ATen/RegisterMkldnnCPU.cpp",
"aten/src/ATen/RegisterNestedTensorCPU.cpp",
"aten/src/ATen/RegisterQuantizedCPU.cpp",
"aten/src/ATen/RegisterSparseCPU.cpp",
"aten/src/ATen/RegisterSparseCsrCPU.cpp",
"aten/src/ATen/RegisterZeroTensor.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutograd.cpp",
"aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional.cpp",
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterSparseMeta.cpp",
"aten/src/ATen/RegisterQuantizedMeta.cpp",
"aten/src/ATen/RegisterNestedTensorMeta.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CPUFunctions_inl.h",

@@ -100,11 +97,11 @@ generated_cpu_cpp = [
generated_cuda_cpp = [
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/CUDAFunctions_inl.h",
"aten/src/ATen/RegisterCUDA_0.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA_0.cpp",
"aten/src/ATen/RegisterQuantizedCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCUDA_0.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA_0.cpp",
"aten/src/ATen/RegisterCUDA.cpp",
"aten/src/ATen/RegisterNestedTensorCUDA.cpp",
"aten/src/ATen/RegisterQuantizedCUDA.cpp",
"aten/src/ATen/RegisterSparseCUDA.cpp",
"aten/src/ATen/RegisterSparseCsrCUDA.cpp",
]

generate_aten(
@@ -876,6 +876,10 @@ cmake_dependent_option(
# feature by default We dont currently document this feature because we don't
# Suspect users building from source will need this
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
add_definitions(-DFLASHATTENTION_DISABLE_SOFTCAP)
add_definitions(-DFLASH_NAMESPACE=pytorch_flash)
# See https://github.com/pytorch/pytorch/issues/121558 for details
add_definitions(-DUNFUSE_FMA)

# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
@@ -164,9 +164,18 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")

# flash_attention sources
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# list(APPEND flash_attention_cuda_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
# list(APPEND flash_attention_cuda_cpp ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp)
# Flash attention C++ sources
file(GLOB flash_attention_cuda_cpp
"${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
"native/transformers/cuda/flash_attn/flash_api.cpp"
)

# file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
# file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
# file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")

@@ -481,9 +490,6 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE AND NOT (MSVC AND CMAKE_SYSTEM_PRO
set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
if(WIN32)
set(SLEEF_BUILD_WITH_LIBM OFF CACHE BOOL "Don't build sleef with libm for Windows." FORCE)
endif()
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
@@ -394,6 +394,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
rocm_fa_preferred_backend = b;
}

bool Context::allowFP16ReductionCuBLAS() const {
return allow_fp16_reduction_cublas;
}

@@ -410,14 +411,6 @@ void Context::setAllowBF16ReductionCuBLAS(bool b) {
allow_bf16_reduction_cublas = b;
}

bool Context::allowFP16AccumulationCuBLAS() const {
return allow_fp16_accumulation_cublas;
}

void Context::setAllowFP16AccumulationCuBLAS(bool b) {
allow_fp16_accumulation_cublas = b;
}

bool Context::hasMKL() {
#if AT_MKL_ENABLED()
@@ -337,8 +337,6 @@ class TORCH_API Context {
void setAllowFP16ReductionCuBLAS(bool);
bool allowBF16ReductionCuBLAS() const;
void setAllowBF16ReductionCuBLAS(bool);
bool allowFP16AccumulationCuBLAS() const;
void setAllowFP16AccumulationCuBLAS(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
static const std::vector<at::QEngine>& supportedQEngines();

@@ -420,7 +418,6 @@ class TORCH_API Context {
bool allow_tf32_cudnn = true;
bool allow_fp16_reduction_cublas = true;
bool allow_bf16_reduction_cublas = true;
bool allow_fp16_accumulation_cublas = false;
bool enabled_mkldnn = true;
bool enabled_nnpack = true;
at::LinalgBackend linalg_preferred_backend =
@@ -54,10 +54,6 @@ bool isAccelerator(c10::DeviceType device_type) {
}
}

// NOLINTBEGIN(bugprone-unchecked-optional-access)
c10::DeviceIndex deviceCount() {
const auto device_type = getAccelerator(false);
if (!device_type.has_value()) {

@@ -100,6 +99,5 @@ void synchronizeDevice(c10::DeviceIndex device_index) {
// impl.synchronizeDevice should can be safely called from any device
impl.synchronizeDevice(device_index);
}
// NOLINTEND(bugprone-unchecked-optional-access)

} // namespace at::accelerator
@ -86,14 +86,14 @@ TaskThreadPoolBase& _get_intraop_pool() {
#endif // C10_MOBILE

// Run lambda function `fn` over `task_id` in [0, `range`) with threadpool.
// `fn` will be called with params: task_id.
static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range) {
// `fn` will be called with params: (thread_pool_task_id, task_id).
void _run_with_pool(const std::function<void(int, size_t)>& fn, size_t range) {
#ifndef C10_MOBILE
for (const auto i : c10::irange(1, range)) {
_get_intraop_pool().run([fn, i]() { fn(i); });
_get_intraop_pool().run([fn, i]() { fn((int)i, i); });
}
// Run the first task on the current thread directly.
fn(0);
fn(0, 0);
#else
caffe2::PThreadPool* const pool = caffe2::pthreadpool();
TORCH_INTERNAL_ASSERT(pool, "Invalid thread pool!");
@ -102,7 +102,7 @@ static void _run_with_pool(const std::function<void(size_t)>& fn, size_t range)
// PThreadPool::run() is blocking. A std::function [const] reference to
// this lambda cannot go out of scope before PThreadPool::run() returns.
[&fn](const size_t task_id) {
fn(task_id);
fn(0 /* unused */, task_id);
}, range);
#endif // C10_MOBILE
}
@ -113,10 +113,6 @@ struct ParallelRegionGuard {
internal::set_thread_num(task_id);
_set_in_parallel_region(true);
}
ParallelRegionGuard(const ParallelRegionGuard&) = delete;
ParallelRegionGuard(ParallelRegionGuard&&) = delete;
ParallelRegionGuard& operator=(const ParallelRegionGuard&) = delete;
ParallelRegionGuard& operator=(ParallelRegionGuard&&) = delete;

~ParallelRegionGuard() {
_set_in_parallel_region(false);
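The hunk above keeps ParallelRegionGuard taking the task id directly; the guard's only job is to mark the current thread as being inside a parallel region and record its thread number for the lifetime of one task. A minimal standalone sketch of that RAII pattern, with illustrative names rather than ATen's internal state:

```cpp
#include <cassert>
#include <iostream>

// Thread-local bookkeeping, analogous in spirit to ATen's internal state.
thread_local bool in_parallel_region = false;
thread_local int thread_num = 0;

// RAII guard: set the thread number and the in-parallel flag on entry,
// restore the defaults on exit, even if the task body throws.
struct RegionGuard {
  explicit RegionGuard(int task_id) {
    thread_num = task_id;
    in_parallel_region = true;
  }
  RegionGuard(const RegionGuard&) = delete;
  RegionGuard& operator=(const RegionGuard&) = delete;
  ~RegionGuard() {
    in_parallel_region = false;
    thread_num = 0;
  }
};

int main() {
  {
    RegionGuard guard(3);
    assert(in_parallel_region && thread_num == 3);
  }
  assert(!in_parallel_region);  // state restored once the guard leaves scope
  std::cout << "guard restored state\n";
  return 0;
}
```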
@ -128,16 +124,16 @@ struct ParallelRegionGuard {

namespace internal {

static std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
inline std::tuple<size_t, size_t> calc_num_tasks_and_chunk_size(
int64_t begin, int64_t end, int64_t grain_size) {
if ((end - begin) < grain_size) {
return std::make_tuple(1, std::max((int64_t)0, end - begin));
}
// Choose number of tasks based on grain size and number of threads.
int64_t chunk_size = divup((end - begin), get_num_threads());
size_t chunk_size = divup((end - begin), get_num_threads());
// Make sure each task is at least grain_size size.
chunk_size = std::max(grain_size, chunk_size);
size_t num_tasks = static_cast<size_t>(divup((end - begin), chunk_size));
chunk_size = std::max((size_t)grain_size, chunk_size);
size_t num_tasks = divup((end - begin), chunk_size);
return std::make_tuple(num_tasks, chunk_size);
}
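The calc_num_tasks_and_chunk_size change above is pure integer arithmetic: divide the range by the thread count, clamp the chunk to the grain size, then derive the task count. A self-contained sketch of the same arithmetic, assuming a fixed thread count instead of ATen's get_num_threads():

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <tuple>

static int64_t divup(int64_t x, int64_t y) { return (x + y - 1) / y; }

// Split [begin, end) into tasks of at least grain_size elements,
// using at most num_threads tasks.
std::tuple<int64_t, int64_t> calc_num_tasks_and_chunk_size(
    int64_t begin, int64_t end, int64_t grain_size, int64_t num_threads) {
  if ((end - begin) < grain_size) {
    return {1, std::max(int64_t(0), end - begin)};
  }
  int64_t chunk_size = divup(end - begin, num_threads);
  chunk_size = std::max(grain_size, chunk_size);
  int64_t num_tasks = divup(end - begin, chunk_size);
  return {num_tasks, chunk_size};
}

int main() {
  auto [tasks, chunk] = calc_num_tasks_and_chunk_size(0, 10000, 1024, 8);
  // 10000 / 8 = 1250 >= 1024, so the chunk stays 1250 and we get 8 tasks.
  std::cout << "tasks=" << tasks << " chunk=" << chunk << "\n";
  return 0;
}
```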
@ -161,12 +157,12 @@ void invoke_parallel(
} state;

auto task = [f, &state, begin, end, chunk_size]
(size_t task_id) {
int64_t local_start = static_cast<int64_t>(begin + task_id * chunk_size);
(int /* unused */, size_t task_id) {
int64_t local_start = begin + task_id * chunk_size;
if (local_start < end) {
int64_t local_end = std::min(end, static_cast<int64_t>(chunk_size + local_start));
int64_t local_end = std::min(end, (int64_t)(chunk_size + local_start));
try {
ParallelRegionGuard guard(static_cast<int>(task_id));
ParallelRegionGuard guard(task_id);
f(local_start, local_end);
} catch (...) {
if (!state.err_flag.test_and_set()) {
@ -656,11 +656,10 @@ struct TORCH_API TensorType : public SharedType {
|
||||
const auto& shape = sizes();
|
||||
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
auto const &s = shape[i];
|
||||
if (!s.has_value()) {
|
||||
if (!shape[i].has_value()) {
|
||||
return std::optional<size_t>{};
|
||||
}
|
||||
prod *= s.value();
|
||||
prod *= shape[i].value();
|
||||
}
|
||||
return prod;
|
||||
}
|
||||
@ -728,11 +727,10 @@ struct TORCH_API TensorType : public SharedType {
|
||||
|
||||
TensorTypePtr contiguous() const {
|
||||
auto cloned = clone();
|
||||
auto concrete_sizes = sizes().concrete_sizes();
|
||||
TORCH_INTERNAL_ASSERT(concrete_sizes.has_value());
|
||||
TORCH_INTERNAL_ASSERT(sizes().concrete_sizes().has_value());
|
||||
auto strides = computeStrideProps(
|
||||
*concrete_sizes,
|
||||
contiguousStridesOf(*concrete_sizes));
|
||||
*sizes().concrete_sizes(),
|
||||
contiguousStridesOf(*sizes().concrete_sizes()));
|
||||
cloned->strides_ = strides;
|
||||
return cloned;
|
||||
}
|
||||
@ -1518,8 +1516,8 @@ struct TORCH_API FunctionType : public NamedType {
|
||||
FunctionType(torch::jit::Function* function);
|
||||
std::string annotation_str_impl(
|
||||
[[maybe_unused]] const TypePrinter& printer = nullptr) const override {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
return name()->qualifiedName();
|
||||
const auto& n = name().value();
|
||||
return n.qualifiedName();
|
||||
}
|
||||
torch::jit::Function* function_;
|
||||
};
|
||||
@ -2135,7 +2133,6 @@ struct MatchTypeReturn {
|
||||
return !reason_.has_value();
|
||||
}
|
||||
const std::string& reason() const {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
return reason_.value();
|
||||
}
|
||||
|
||||
@ -2184,7 +2181,6 @@ struct TORCH_API InterfaceType : public NamedType {
|
||||
}
|
||||
|
||||
std::string str() const override {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
return std::string("InterfaceType<") + name()->name() + ">";
|
||||
}
|
||||
|
||||
@ -2212,7 +2208,6 @@ struct TORCH_API InterfaceType : public NamedType {
|
||||
|
||||
std::string annotation_str_impl(
|
||||
[[maybe_unused]] const TypePrinter& printer = nullptr) const override {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
return name()->qualifiedName();
|
||||
}
|
||||
|
||||
|
||||
@ -904,7 +904,6 @@ bool ListType::isSubtypeOfExt(const Type& rhs_, std::ostream* why_not) const {
|
||||
std::string TupleType::str() const {
|
||||
std::stringstream ss;
|
||||
if (schema_ && name().has_value()) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
ss << name()->qualifiedName();
|
||||
} else {
|
||||
ss << "(";
|
||||
|
||||
@ -526,7 +526,12 @@ Vectorized<c10::BFloat16> inline fmadd(
// elements, not the bottom and top half, so they don't seem
// particularly useful here. Ideally we would include dot product in
// the Vectorized interface...
return a * b + c;
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
const auto [c_float_low, c_float_high] = convert_bfloat16_float(c);
return convert_float_bfloat16(
fmadd(a_float_low, b_float_low, c_float_low),
fmadd(a_float_high, b_float_high, c_float_high));
}

template <>
@ -535,7 +540,12 @@ Vectorized<c10::BFloat16> inline fmsub(
const Vectorized<c10::BFloat16>& b,
const Vectorized<c10::BFloat16>& c) {
// See NOTE [BF16 FMA] above.
return a * b - c;
const auto [a_float_low, a_float_high] = convert_bfloat16_float(a);
const auto [b_float_low, b_float_high] = convert_bfloat16_float(b);
const auto [c_float_low, c_float_high] = convert_bfloat16_float(c);
return convert_float_bfloat16(
fmsub(a_float_low, b_float_low, c_float_low),
fmsub(a_float_high, b_float_high, c_float_high));
}

#endif // !defined(C10_MOBILE) && defined(__aarch64__)
@ -572,7 +572,12 @@ Vectorized<c10::Half> inline fmadd(
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
return Vectorized<c10::Half>(vfmaq_f16(c, a, b));
#else
return a * b + c;
const auto [a_float_low, a_float_high] = convert_half_float(a);
const auto [b_float_low, b_float_high] = convert_half_float(b);
const auto [c_float_low, c_float_high] = convert_half_float(c);
return convert_float_half(
fmadd(a_float_low, b_float_low, c_float_low),
fmadd(a_float_high, b_float_high, c_float_high));
#endif
}

@ -584,7 +589,12 @@ Vectorized<c10::Half> inline fmsub(
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
return Vectorized<c10::Half>(vfmsq_f16(c, a, b));
#else
return a * b - c;
const auto [a_float_low, a_float_high] = convert_half_float(a);
const auto [b_float_low, b_float_high] = convert_half_float(b);
const auto [c_float_low, c_float_high] = convert_half_float(c);
return convert_float_half(
fmsub(a_float_low, b_float_low, c_float_low),
fmsub(a_float_high, b_float_high, c_float_high));
#endif
}
#endif // !defined(C10_MOBILE) && defined(__aarch64__)
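Both restored bodies above emulate fused multiply-add for reduced-precision vectors by widening each half to float, doing the FMA in float, and narrowing back. A scalar sketch of the same widening trick for bfloat16; the narrowing here truncates for brevity, whereas the real conversions round to nearest even:

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 is the top 16 bits of an IEEE-754 float, so widening is a shift.
float bf16_to_float(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Narrowing sketch: truncate the mantissa (real code rounds to nearest even).
uint16_t float_to_bf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return uint16_t(bits >> 16);
}

// fmadd on bf16 operands, accumulated in float as in the vectorized version.
uint16_t bf16_fmadd(uint16_t a, uint16_t b, uint16_t c) {
  return float_to_bf16(
      std::fma(bf16_to_float(a), bf16_to_float(b), bf16_to_float(c)));
}

int main() {
  uint16_t two = float_to_bf16(2.0f);
  uint16_t three = float_to_bf16(3.0f);
  uint16_t one = float_to_bf16(1.0f);
  std::cout << bf16_to_float(bf16_fmadd(two, three, one)) << "\n";  // prints 7
  return 0;
}
```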
@ -284,7 +284,6 @@ class CuBlasLtMatmulDescriptor : public CuBlasLtDescriptor<
}
template <typename T>
inline void setAttribute(cublasLtMatmulDescAttributes_t attr, const T value) {
// NOLINTNEXTLINE(bugprone-sizeof-expression)
TORCH_CUDABLAS_CHECK(::cublasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
}
};
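The descriptor class above follows a pattern used throughout this file: an RAII wrapper that owns an opaque library handle and forwards typed attributes through one setAttribute template, taking the byte size from the value type. A sketch of that idiom against a mocked C-style API; the desc_* functions below are stand-ins for illustration only, not cublasLt calls:

```cpp
#include <cstdio>
#include <map>
#include <vector>

// --- mocked C-style API (stands in for the real handle functions) ---
struct OpaqueDesc { std::map<int, std::vector<unsigned char>> attrs; };
OpaqueDesc* desc_create() { return new OpaqueDesc(); }
void desc_destroy(OpaqueDesc* d) { delete d; }
void desc_set_attribute(OpaqueDesc* d, int attr, const void* buf, size_t len) {
  auto* p = static_cast<const unsigned char*>(buf);
  d->attrs[attr] = std::vector<unsigned char>(p, p + len);
}

// --- RAII wrapper with a typed setAttribute, mirroring the descriptor idiom ---
class Descriptor {
 public:
  Descriptor() : desc_(desc_create()) {}
  ~Descriptor() { desc_destroy(desc_); }
  Descriptor(const Descriptor&) = delete;
  Descriptor& operator=(const Descriptor&) = delete;

  template <typename T>
  void setAttribute(int attr, const T& value) {
    // The attribute size comes from the value type, exactly once, here.
    desc_set_attribute(desc_, attr, &value, sizeof(value));
  }

  OpaqueDesc* raw() const { return desc_; }

 private:
  OpaqueDesc* desc_;
};

int main() {
  Descriptor d;
  d.setAttribute(/*attr=*/1, int{42});
  d.setAttribute(/*attr=*/2, double{0.5});
  std::printf("stored %zu attributes\n", d.raw()->attrs.size());
  return 0;
}
```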
@ -332,12 +331,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t scaleType = CUDA_R_32F;
#ifndef USE_ROCM
at::Half halpha;
at::Half hbeta;
#endif
void * alpha_ptr = &alpha;
void * beta_ptr = &beta;
if constexpr (std::is_same_v<Dtype, double>) {
abcType = CUDA_R_64F;
computeType = CUBLAS_COMPUTE_64F;
@ -354,16 +347,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
abcType = CUDA_C_32F;
scaleType = CUDA_C_32F;
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
#ifndef USE_ROCM
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
computeType = CUBLAS_COMPUTE_16F;
halpha = alpha;
hbeta = beta;
alpha_ptr = &halpha;
beta_ptr = &hbeta;
}
#endif
abcType = CUDA_R_16F;
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
abcType = CUDA_R_16BF;
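The block removed above is the heart of the reverted feature: when FP16 accumulation is allowed and the device is SM 7.0 or newer, the compute type drops to FP16 and alpha/beta have to be re-staged in half precision so that their type matches the compute type. A standalone sketch of that selection logic with placeholder enums, not the cuBLAS API:

```cpp
#include <cstdint>
#include <iostream>

enum class ComputeType { F32, F16 };

// Scalars staged next to the descriptor; alpha/beta must live in the same
// precision as the chosen compute type.
struct GemmScalars {
  ComputeType compute = ComputeType::F32;
  float falpha = 1.0f, fbeta = 0.0f;   // fp32 staging
  uint16_t halpha = 0, hbeta = 0;      // fp16 bit patterns
  const void* alpha_ptr = nullptr;
  const void* beta_ptr = nullptr;
};

// Toy float->fp16 conversion, enough for 0.0 and 1.0 in this illustration.
static uint16_t to_half_bits(float f) { return f == 1.0f ? 0x3C00 : 0x0000; }

// Mirrors the removed branch: on SM 7.0+ with fp16 accumulation allowed,
// switch the compute type and repoint alpha/beta at half-precision copies.
void pick_scalars(GemmScalars& s, float alpha, float beta,
                  bool allow_fp16_accum, int sm_major) {
  s.falpha = alpha;
  s.fbeta = beta;
  s.alpha_ptr = &s.falpha;
  s.beta_ptr = &s.fbeta;
  if (allow_fp16_accum && sm_major >= 7) {
    s.compute = ComputeType::F16;
    s.halpha = to_half_bits(alpha);
    s.hbeta = to_half_bits(beta);
    s.alpha_ptr = &s.halpha;
    s.beta_ptr = &s.hbeta;
  }
}

int main() {
  GemmScalars s;
  pick_scalars(s, 1.0f, 0.0f, /*allow_fp16_accum=*/true, /*sm_major=*/8);
  std::cout << (s.compute == ComputeType::F16 ? "fp16 accumulation\n"
                                              : "fp32 accumulation\n");
  return 0;
}
```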
@ -409,7 +392,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
|
||||
#endif
|
||||
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
@ -431,12 +414,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
alpha_ptr,
|
||||
&alpha,
|
||||
a,
|
||||
Adesc.descriptor(),
|
||||
b,
|
||||
Bdesc.descriptor(),
|
||||
beta_ptr,
|
||||
&beta,
|
||||
c,
|
||||
Cdesc.descriptor(),
|
||||
c,
|
||||
@ -546,11 +529,6 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
BGEMM_CHECK_ARGVALUES(at::Half);
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
void * alpha_ptr = &falpha;
|
||||
void * beta_ptr = &fbeta;
|
||||
auto compute_type = CUDA_R_32F;
|
||||
#ifdef USE_ROCM
|
||||
int flag = 0;
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
@ -567,20 +545,13 @@ void bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
|
||||
0, flag)));
|
||||
#else
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
halpha = alpha;
|
||||
hbeta = beta;
|
||||
compute_type = CUDA_R_16F;
|
||||
alpha_ptr = &halpha;
|
||||
beta_ptr = &hbeta;
|
||||
}
|
||||
if (prop->major >= 5){
|
||||
TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(
|
||||
handle, opa, opb, m, n, k,
|
||||
alpha_ptr, a, CUDA_R_16F, lda, stridea,
|
||||
b, CUDA_R_16F, ldb, strideb, beta_ptr,
|
||||
(void*)(&falpha), a, CUDA_R_16F, lda, stridea,
|
||||
b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
|
||||
c, CUDA_R_16F, ldc, stridec,
|
||||
num_batches, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
} else {
|
||||
for (const auto i : c10::irange(num_batches)) {
|
||||
at::cuda::blas::gemm<at::Half>(
|
||||
@ -895,13 +866,8 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
cublasOperation_t opb = _cublasOpFromChar(transb);
|
||||
float falpha = alpha;
|
||||
float fbeta = beta;
|
||||
at::Half halpha;
|
||||
at::Half hbeta;
|
||||
void * alpha_ptr = &falpha;
|
||||
void * beta_ptr = &fbeta;
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(at::Half);
|
||||
auto compute_type = CUDA_R_32F;
|
||||
#ifdef USE_ROCM
|
||||
int flag = 0;
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
@ -934,18 +900,13 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
flag)));
|
||||
#else
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
compute_type = CUDA_R_16F;
|
||||
halpha = alpha;
|
||||
hbeta = beta;
|
||||
alpha_ptr = &halpha;
|
||||
beta_ptr = &hbeta;
|
||||
}
|
||||
if (prop->major >= 5) {
|
||||
#ifndef USE_ROCM
|
||||
cublasMath_t cublas_flags = CUBLAS_DEFAULT_MATH;
|
||||
if (!at::globalContext().allowFP16ReductionCuBLAS()) {
|
||||
cublas_flags = static_cast<cublasMath_t>(cublas_flags | CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION);
|
||||
}
|
||||
#endif
|
||||
// Disallow fp16 reductions that could lead to unexpected overflow issues.
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, cublas_flags));
|
||||
TORCH_CUDABLAS_CHECK(cublasGemmEx(
|
||||
@ -955,18 +916,18 @@ void gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
alpha_ptr,
|
||||
&falpha,
|
||||
a,
|
||||
CUDA_R_16F,
|
||||
lda,
|
||||
b,
|
||||
CUDA_R_16F,
|
||||
ldb,
|
||||
beta_ptr,
|
||||
&fbeta,
|
||||
c,
|
||||
CUDA_R_16F,
|
||||
ldc,
|
||||
compute_type,
|
||||
CUDA_R_32F,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
|
||||
} else {
|
||||
@ -1268,12 +1229,6 @@ void gemm_and_bias(
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
|
||||
cudaDataType_t scaleType = CUDA_R_32F;
|
||||
void * alpha_ptr = &alpha_val;
|
||||
void * beta_ptr = &beta_val;
|
||||
#ifndef USE_ROCM
|
||||
at::Half halpha_val;
|
||||
at::Half hbeta_val;
|
||||
#endif
|
||||
if constexpr (std::is_same_v<Dtype, double>) {
|
||||
abcType = CUDA_R_64F;
|
||||
computeType = CUBLAS_COMPUTE_64F;
|
||||
@ -1284,17 +1239,6 @@ void gemm_and_bias(
|
||||
}
|
||||
abcType = CUDA_R_32F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::Half>) {
|
||||
#ifndef USE_ROCM
|
||||
cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
|
||||
if (prop->major >= 7 && at::globalContext().allowFP16AccumulationCuBLAS()) {
|
||||
computeType = CUBLAS_COMPUTE_16F;
|
||||
scaleType = CUDA_R_16F;
|
||||
halpha_val = alpha_val;
|
||||
hbeta_val = beta_val;
|
||||
alpha_ptr = &halpha_val;
|
||||
beta_ptr = &hbeta_val;
|
||||
}
|
||||
#endif
|
||||
abcType = CUDA_R_16F;
|
||||
} else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
|
||||
abcType = CUDA_R_16BF;
|
||||
@ -1340,7 +1284,7 @@ void gemm_and_bias(
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
|
||||
#endif
|
||||
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
cublasLtMatmulHeuristicResult_t heuristicResult = {};
|
||||
int returnedResult = 0;
|
||||
@ -1363,12 +1307,12 @@ void gemm_and_bias(
|
||||
cublasStatus_t cublasStatus = cublasLtMatmul(
|
||||
ltHandle,
|
||||
computeDesc.descriptor(),
|
||||
alpha_ptr,
|
||||
&alpha_val,
|
||||
mat1_ptr,
|
||||
Adesc.descriptor(),
|
||||
mat2_ptr,
|
||||
Bdesc.descriptor(),
|
||||
beta_ptr,
|
||||
&beta_val,
|
||||
result_ptr,
|
||||
Cdesc.descriptor(),
|
||||
result_ptr,
|
||||
@ -1522,7 +1466,7 @@ void scaled_gemm(
|
||||
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
|
||||
}
|
||||
size_t workspaceSize = _getWorkspaceSize();
|
||||
auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
|
||||
|
||||
CuBlasLtMatmulPreference preference;
|
||||
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
|
||||
|
||||
@ -56,6 +56,7 @@ cusparseIndexType_t getCuSparseIndexType(const c10::ScalarType& scalar_type) {
|
||||
}
|
||||
}
|
||||
|
||||
#if AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||
cusparseDnMatDescr_t createRawDnMatDescriptor(const Tensor& input, int64_t batch_offset, bool is_const=false) {
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.layout() == kStrided);
|
||||
IntArrayRef input_strides = input.strides();
|
||||
@ -120,6 +121,7 @@ CuSparseDnMatDescriptor::CuSparseDnMatDescriptor(const Tensor& input, int64_t ba
|
||||
CuSparseConstDnMatDescriptor::CuSparseConstDnMatDescriptor(const Tensor& input, int64_t batch_offset) {
|
||||
descriptor_.reset(createRawDnMatDescriptor(input, batch_offset, /*is_const*/true));
|
||||
}
|
||||
#endif // AT_USE_CUSPARSE_GENERIC_API() || AT_USE_HIPSPARSE_GENERIC_API()
|
||||
|
||||
CuSparseDnVecDescriptor::CuSparseDnVecDescriptor(const Tensor& input) {
|
||||
// cuSPARSE doesn't support batched vectors
|
||||
|
||||
@ -116,7 +116,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||
|
||||
|
||||
virtual void recordMemoryHistory(
|
||||
const std::optional<std::string>& enabled,
|
||||
std::optional<std::string> enabled,
|
||||
const std::string& stacks,
|
||||
size_t max_entries) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
|
||||
@ -162,7 +162,6 @@ grid_sample_backward_helper_in(
|
||||
|
||||
static std::tuple<Tensor, std::optional<int64_t>, Tensor, std::optional<int64_t>>
|
||||
grid_sample_backward_helper_out(
|
||||
// NOLINTNEXTLINE(performance-unnecessary-value-param)
|
||||
std::tuple<Tensor, Tensor> bw_out,
|
||||
int64_t grad_input_out_bdim,
|
||||
int64_t grad_grid_out_bdim,
|
||||
@ -262,7 +261,7 @@ struct UpsampleBackwardBatchRuleHelper<F, Func, typelist<A, B, C, T...>> {
|
||||
|
||||
auto out = Func(
|
||||
std::move(grad_output_),
|
||||
output_size,
|
||||
std::move(output_size),
|
||||
std::move(physical_input_size),
|
||||
std::forward<T>(extra_args)...);
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
|
||||
@ -16,7 +16,6 @@
|
||||
// registered to FuncTorchVmapMode. This is because we need to interpose on
|
||||
// random operations even if they're not on a BatchedTensor.
|
||||
|
||||
// NOLINTBEGIN(bugprone-unchecked-optional-access)
|
||||
namespace at::functorch {
|
||||
|
||||
template <typename F, F Func, typename... ExtraArgs>
|
||||
@ -502,4 +501,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
|
||||
}
|
||||
|
||||
} // namespace at::functorch
|
||||
// NOLINTEND(bugprone-unchecked-optional-access)
|
||||
|
||||
@ -11,7 +11,6 @@
|
||||
|
||||
#include <utility>
|
||||
|
||||
// NOLINTBEGIN(bugprone-unchecked-optional-access)
|
||||
namespace at::functorch {
|
||||
|
||||
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
|
||||
@ -511,4 +510,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
}
|
||||
|
||||
} // namespace at::functorch
|
||||
// NOLINTEND(bugprone-unchecked-optional-access)
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
#include <torch/library.h>
|
||||
|
||||
|
||||
// NOLINTBEGIN(bugprone-unchecked-optional-access)
|
||||
namespace at::functorch {
|
||||
|
||||
namespace {
|
||||
@ -1284,4 +1283,3 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
||||
}
|
||||
|
||||
} // namespace at::functorch
|
||||
// NOLINTEND(bugprone-unchecked-optional-access)
|
||||
|
||||
@ -156,7 +156,6 @@ const Tensor& resize__plumbing(
|
||||
"resize_: batching rule only supports None or Contiguous MemoryFormat");
|
||||
auto maybe_layer = maybeCurrentDynamicLayer();
|
||||
vmap_check_escaped(maybe_layer, "resize__plumbing");
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
int64_t cur_level = maybe_layer->layerId();
|
||||
if (!isBatchedAtLevel(self, cur_level)) {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard2(DispatchKey::FuncTorchBatched);
|
||||
|
||||
@ -41,7 +41,6 @@ DynamicLayer::DynamicLayer(
|
||||
}
|
||||
switch (transform_type) {
|
||||
case TransformType::Vmap:
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
interpreter_ = Interpreter::Vmap(layerId, std::move(batchSize.value()), randomness.value());
|
||||
break;
|
||||
case TransformType::Grad:
|
||||
@ -51,7 +50,6 @@ DynamicLayer::DynamicLayer(
|
||||
interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value());
|
||||
break;
|
||||
case TransformType::Functionalize:
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
interpreter_ = Interpreter::Functionalize(layerId, functionalize_add_back_views.value());
|
||||
break;
|
||||
default:
|
||||
@ -347,7 +345,9 @@ void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int6
|
||||
if (!ivalue.isTensor()) {
|
||||
continue;
|
||||
}
|
||||
args[idx] = func(ivalue.toTensor(), flag);
|
||||
Tensor value = ivalue.toTensor();
|
||||
Tensor replacement = func(value, flag);
|
||||
args[idx] = std::move(replacement);
|
||||
// sanity checks
|
||||
if (ivalue.toTensor().defined()) {
|
||||
TORCH_INTERNAL_ASSERT(args[idx].toTensor().defined());
|
||||
|
||||
@ -118,7 +118,6 @@ static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional<int64_t> dim,
|
||||
// to `batch_sizes`
|
||||
VmapPhysicalViewVec
|
||||
MultiBatchVmapTransform::logicalToPhysical(ITensorListRef logical_tensors) {
|
||||
// NOLINTNEXTLINE(bugprone-unchecked-optional-access)
|
||||
auto cur_level = maybeCurrentDynamicLayer().value().layerId();
|
||||
c10::SymInt bdim_size = -1;
|
||||
|
||||
|
||||
@ -16,10 +16,6 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#ifndef __OBJC__
|
||||
typedef void* MTLCaptureManager;
|
||||
#endif
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
namespace Profiler {
|
||||
@ -62,7 +58,24 @@ struct BaseInfo {
|
||||
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
|
||||
static std::string buildTensorString(
|
||||
const Tensor& tensor,
|
||||
bool includeBufferId = false);
|
||||
bool includeBufferId = false) {
|
||||
if (tensor.defined()) {
|
||||
std::stringstream tensorStr;
|
||||
auto deviceType = tensor.device().type();
|
||||
tensorStr << c10::DeviceTypeName(deviceType);
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer =
|
||||
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
|
||||
<< buffer.retainCount << ")";
|
||||
}
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
}
|
||||
}
|
||||
static uint64_t getTime() {
|
||||
return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
|
||||
}
|
||||
|
||||
@ -30,23 +30,6 @@ const std::string BaseInfo::toString(double gpuTime, double schedulingTime) cons
|
||||
schedulingTime > 0.0 ? fmt::format(", cpu={:.3f} ms", schedulingTime) : "");
|
||||
}
|
||||
|
||||
std::string BaseInfo::buildTensorString(const Tensor& tensor, bool includeBufferId) {
|
||||
if (tensor.defined()) {
|
||||
std::stringstream tensorStr;
|
||||
auto deviceType = tensor.device().type();
|
||||
tensorStr << c10::DeviceTypeName(deviceType);
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":" << buffer.retainCount << ")";
|
||||
}
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
}
|
||||
}
|
||||
|
||||
const std::string OperationInfo::toString(double gpuTime, double schedulingTime) const {
|
||||
return fmt::format("aten::{} (id={}{}, run={}{})",
|
||||
strKey,
|
||||
|
||||
@ -15,26 +15,21 @@
|
||||
#include <Metal/Metal.h>
|
||||
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
||||
typedef MPSCommandBuffer* MPSCommandBuffer_t;
|
||||
typedef id<MTLCommandQueue> MTLCommandQueue_t;
|
||||
typedef id<MTLCommandBuffer> MTLCommandBuffer_t;
|
||||
typedef id<MTLComputeCommandEncoder> MTLComputeCommandEncoder_t;
|
||||
typedef id<MTLSharedEvent> MTLSharedEvent_t;
|
||||
typedef id<MTLDevice> MTLDevice_t;
|
||||
typedef id<MTLBuffer> MTLBuffer_t;
|
||||
#else
|
||||
#include <dispatch/dispatch.h>
|
||||
typedef void* MPSCommandBuffer_t;
|
||||
typedef void* MPSGraph;
|
||||
typedef void* MPSGraphExecutionDescriptor;
|
||||
typedef void* MPSGraphCompilationDescriptor;
|
||||
typedef void* MTLCommandQueue_t;
|
||||
typedef void* MTLCommandQueue;
|
||||
typedef void* MTLCommandBuffer_t;
|
||||
typedef void* MTLCommandBuffer;
|
||||
typedef void* MTLComputeCommandEncoder_t;
|
||||
typedef void* MTLSharedEvent_t;
|
||||
typedef void* dispatch_queue_t;
|
||||
typedef void* MTLDevice_t;
|
||||
typedef void* MTLBuffer_t;
|
||||
typedef void* MTLCommandBufferHandler;
|
||||
typedef void* NSDictionary;
|
||||
#define nil NULL
|
||||
#define nil NULL;
|
||||
#endif
|
||||
|
||||
namespace at::mps {
|
||||
@ -60,29 +55,27 @@ class TORCH_API MPSStream {
|
||||
explicit MPSStream(Stream stream);
|
||||
|
||||
~MPSStream();
|
||||
|
||||
MTLCommandQueue_t commandQueue() const {
|
||||
return _commandQueue;
|
||||
}
|
||||
|
||||
};
|
||||
dispatch_queue_t queue() const {
|
||||
return _serialQueue;
|
||||
}
|
||||
|
||||
MPSCommandBuffer_t commandBuffer();
|
||||
MPSCommandBuffer* commandBuffer();
|
||||
MTLComputeCommandEncoder_t commandEncoder();
|
||||
void endKernelCoalescing();
|
||||
void synchronize(SyncType syncType);
|
||||
void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
|
||||
void copy(MTLBuffer_t srcBuffer,
|
||||
MTLBuffer_t dstBuffer,
|
||||
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
|
||||
void copy(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
uint64_t profileId,
|
||||
SyncType syncType = SyncType::NONE);
|
||||
void copy_and_sync(MTLBuffer_t srcBuffer,
|
||||
MTLBuffer_t dstBuffer,
|
||||
void copy_and_sync(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
@ -101,9 +94,11 @@ class TORCH_API MPSStream {
|
||||
|
||||
MTLCommandQueue_t stream() const {
|
||||
return _commandQueue;
|
||||
}
|
||||
};
|
||||
|
||||
MTLDevice_t device() const;
|
||||
MTLDevice_t device() const {
|
||||
return [_commandQueue device];
|
||||
}
|
||||
|
||||
/// Explicit conversion to Stream.
|
||||
Stream unwrap() const {
|
||||
@ -113,8 +108,8 @@ class TORCH_API MPSStream {
|
||||
private:
|
||||
Stream _stream;
|
||||
MTLCommandQueue_t _commandQueue = nil;
|
||||
MPSCommandBuffer_t _commandBuffer = nil;
|
||||
MPSCommandBuffer_t _prevCommandBuffer = nil;
|
||||
MPSCommandBuffer* _commandBuffer = nil;
|
||||
MPSCommandBuffer* _prevCommandBuffer = nil;
|
||||
MTLComputeCommandEncoder_t _commandEncoder = nil;
|
||||
MPSGraphExecutionDescriptor* _executionDescriptor = nil;
|
||||
MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
|
||||
|
||||
@ -51,10 +51,6 @@ MPSCommandBuffer* MPSStream::commandBuffer() {
|
||||
return _commandBuffer;
|
||||
}
|
||||
|
||||
id<MTLDevice> MPSStream::device() const {
|
||||
return [_commandQueue device];
|
||||
}
|
||||
|
||||
id<MTLComputeCommandEncoder> MPSStream::commandEncoder() {
|
||||
if (!_commandEncoder) {
|
||||
_commandEncoder = [commandBuffer() computeCommandEncoder].retain;
|
||||
|
||||
@ -579,7 +579,7 @@ static void _rrelu_with_noise_train(
|
||||
Tensor& noise,
|
||||
const Scalar& lower_,
|
||||
const Scalar& upper_,
|
||||
const std::optional<Generator>& generator) {
|
||||
std::optional<Generator> generator) {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
opmath_t lower = lower_.to<opmath_t>();
|
||||
opmath_t upper = upper_.to<opmath_t>();
|
||||
|
||||
@ -946,8 +946,6 @@ inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
|
||||
return dnnl::memory::data_type::bf16;
|
||||
} else if (dtype == ScalarType::Half) {
|
||||
return dnnl::memory::data_type::f16;
|
||||
} else if (dtype == ScalarType::Int) {
|
||||
return dnnl::memory::data_type::s32;
|
||||
} else if (dtype == ScalarType::Byte) {
|
||||
return dnnl::memory::data_type::u8;
|
||||
} else if (dtype == ScalarType::Char) {
|
||||
@ -1093,7 +1091,7 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
int64_t(1),
|
||||
1,
|
||||
ld_a,
|
||||
ld_b,
|
||||
ld_c,
|
||||
@ -1133,12 +1131,6 @@ struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
|
||||
} else if (dtype == ScalarType::BFloat16) {
|
||||
static bool bf16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core;
|
||||
return bf16_support;
|
||||
} else if (dtype == ScalarType::Byte) {
|
||||
static bool u8_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
|
||||
return u8_support;
|
||||
} else if (dtype == ScalarType::Char) {
|
||||
static bool s8_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_vnni;
|
||||
return s8_support;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -1189,9 +1181,6 @@ struct Pack : public KernelCache <PackKey, pack_t> {
|
||||
} else if (dtype == ScalarType::BFloat16) {
|
||||
static bool bf16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
|
||||
return bf16_pack;
|
||||
} else if (dtype == ScalarType::Byte || dtype == ScalarType::Char) {
|
||||
static bool bit8_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx;
|
||||
return bit8_pack;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@ -1293,54 +1282,6 @@ void brgemm(
|
||||
beta, C, ld_c);
|
||||
}
|
||||
|
||||
void brgemm(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t ld_a,
|
||||
int64_t ld_b,
|
||||
int64_t ld_c,
|
||||
const bool add_C,
|
||||
const unsigned char* A,
|
||||
const unsigned char* B,
|
||||
int32_t* C,
|
||||
bool is_vnni) {
|
||||
#if defined(ONEDNN_UKERNEL_ENABLED)
|
||||
if (is_vnni && Brgemm::device_check(ScalarType::Byte)) {
|
||||
Brgemm::call<unsigned char, unsigned char, int32_t>(
|
||||
M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
// raise an error if the path is not supported
|
||||
TORCH_CHECK(false,
|
||||
"U8 Brgemm is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
|
||||
}
|
||||
|
||||
void brgemm(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t ld_a,
|
||||
int64_t ld_b,
|
||||
int64_t ld_c,
|
||||
const bool add_C,
|
||||
const unsigned char* A,
|
||||
const signed char* B,
|
||||
int32_t* C,
|
||||
bool is_vnni) {
|
||||
#if defined(ONEDNN_UKERNEL_ENABLED)
|
||||
if (is_vnni && Brgemm::device_check(ScalarType::Char)) {
|
||||
Brgemm::call<unsigned char, signed char, int32_t>(
|
||||
M, N, K, ld_a, ld_b, ld_c, add_C, A, B, C);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
// raise an error if the path is not supported
|
||||
TORCH_CHECK(false,
|
||||
"I8 Brgemm is only supported on X64 when oneDNN ukernel is enabled and `amx` is supported");
|
||||
}
|
||||
|
||||
void brgemm_release(bool is_vnni) {
|
||||
#if defined(ONEDNN_UKERNEL_ENABLED)
|
||||
if (is_vnni) {
|
||||
|
||||
@ -233,37 +233,11 @@ TORCH_API void brgemm(
|
||||
float* C,
|
||||
bool is_vnni = false);
|
||||
|
||||
TORCH_API void brgemm(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t ld_a,
|
||||
int64_t ld_b,
|
||||
int64_t ld_c,
|
||||
const bool add_C,
|
||||
const unsigned char* A,
|
||||
const unsigned char* B,
|
||||
int32_t* C,
|
||||
bool is_vnni = true);
|
||||
|
||||
TORCH_API void brgemm(
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t ld_a,
|
||||
int64_t ld_b,
|
||||
int64_t ld_c,
|
||||
const bool add_C,
|
||||
const unsigned char* A,
|
||||
const signed char* B,
|
||||
int32_t* C,
|
||||
bool is_vnni = true);
|
||||
|
||||
// Release brgemm hardware context
|
||||
TORCH_API void brgemm_release(bool is_vnni = true);
|
||||
|
||||
// Pack B matrix to get better performance if needed
|
||||
TORCH_API void pack(
|
||||
void pack(
|
||||
int64_t K,
|
||||
int64_t N,
|
||||
int64_t ld_in,
|
||||
|
||||
@ -620,7 +620,11 @@ Tensor _conj_physical(const Tensor& self) {
|
||||
if (self.is_conj()) {
|
||||
return self.conj().clone();
|
||||
}
|
||||
auto result = at::empty_like(self);
|
||||
auto options = self.options();
|
||||
if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
|
||||
options = options.dtype(c10::get_default_dtype());
|
||||
}
|
||||
auto result = at::empty_like(self, options);
|
||||
return at::conj_physical_out(result, self);
|
||||
}
|
||||
|
||||
|
||||
@ -1,21 +1,21 @@
#pragma once

#include <ATen/cpu/vec/vec.h>
#include <c10/util/BFloat16.h> // For c10::is_reduced_floating_point_v.
#include <c10/util/BFloat16.h> // For std::is_reduced_floating_point_v.

namespace at::native {
constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
constexpr double kGeluKappa = 0.044715;

template <typename T>
using reduced_fp_to_float_t = std::conditional_t<c10::is_reduced_floating_point_v<T>, float, T>;
using reduced_fp_to_float_t = std::conditional_t<std::is_reduced_floating_point_v<T>, float, T>;

template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
float reduced_fp_to_float(T x) {
return float(x);
}

template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
T reduced_fp_to_float(T x) {
return x;
}
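This header keeps kGeluBeta (sqrt(2/pi)) and kGeluKappa (0.044715), the constants behind the tanh approximation of GELU. A small scalar comparison of the exact erf form and the tanh form using those same constants; this is a sketch of the math, not the vectorized code paths from the hunks above:

```cpp
#define _USE_MATH_DEFINES
#include <cmath>
#include <cstdio>

// Same constants as the header: kGeluBeta = sqrt(2/pi), kGeluKappa = 0.044715.
constexpr double kGeluBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
constexpr double kGeluKappa = 0.044715;

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
double gelu_exact(double x) {
  return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
}

// Tanh approximation: 0.5 * x * (1 + tanh(beta * (x + kappa * x^3))).
double gelu_tanh(double x) {
  double inner = kGeluBeta * (x + kGeluKappa * x * x * x);
  return 0.5 * x * (1.0 + std::tanh(inner));
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("x=% .1f exact=% .6f tanh=% .6f\n",
                x, gelu_exact(x), gelu_tanh(x));
  }
  return 0;
}
```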
@ -29,7 +29,7 @@ T scalar_gelu_approximated_with_tanh(T x) {
|
||||
return opmath_t(0.5) * x_float * (opmath_t(1) + std::tanh(inner));
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||
template <typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
|
||||
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
||||
const vec::Vectorized<T> kPointFiveVec(T(0.5));
|
||||
const vec::Vectorized<T> kOneVec(T(1));
|
||||
@ -40,7 +40,7 @@ vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x)
|
||||
return kPointFiveVec * x * (kOneVec + inner_vec.tanh());
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||
template <typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
|
||||
vec::Vectorized<T> vectorized_gelu_approximated_with_tanh(vec::Vectorized<T> x) {
|
||||
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
||||
return at::vec::convert_from_float<T>(
|
||||
@ -56,7 +56,7 @@ T scalar_gelu(T x) {
|
||||
return reduced_fp_to_float(x) * opmath_t(0.5) * (opmath_t(1) + std::erf(reduced_fp_to_float(x) * kAlpha));
|
||||
}
|
||||
|
||||
template<typename T, std::enable_if_t<!c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||
template<typename T, std::enable_if_t<!std::is_reduced_floating_point_v<T>, bool> = true>
|
||||
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||
const vec::Vectorized<T> kAlphaVec(T(M_SQRT1_2));
|
||||
const vec::Vectorized<T> kOneVec(T(1));
|
||||
@ -64,7 +64,7 @@ vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||
return x * kPointFiveVec * (kOneVec + (x * kAlphaVec).erf());
|
||||
}
|
||||
|
||||
template<typename T, std::enable_if_t<c10::is_reduced_floating_point_v<T>, bool> = true>
|
||||
template<typename T, std::enable_if_t<std::is_reduced_floating_point_v<T>, bool> = true>
|
||||
vec::Vectorized<T> vectorized_gelu(vec::Vectorized<T> x) {
|
||||
auto [x0, x1] = at::vec::convert_to_float<T>(x);
|
||||
return at::vec::convert_from_float<T>(vectorized_gelu(x0), vectorized_gelu(x1));
|
||||
|
||||
@ -995,7 +995,7 @@ static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
dst_f32 += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
size_t m,
|
||||
|
||||
@ -964,7 +964,7 @@ ScalingType get_scaling_type(
|
||||
} // namespace
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices
|
||||
// Scales are only applicable when matrices are of Float8 type and assumed to be equal to 1.0 by default.
|
||||
// Scales are only applicable when matrices are of Float8 type and assumbed to be equal to 1.0 by default.
|
||||
// If output matrix type is 16 or 32-bit type, scale_result is not applied.
|
||||
// Known limitations:
|
||||
// - Only works if mat1 is row-major and mat2 is column-major
|
||||
|
||||
@ -19,7 +19,37 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
|
||||
#if defined(BUILD_ROWWISE_FP8_KERNEL)
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
|
||||
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
|
||||
CUtensorMap* tensorMap,
|
||||
CUtensorMapDataType tensorDataType,
|
||||
cuuint32_t tensorRank,
|
||||
void* globalAddress,
|
||||
const cuuint64_t* globalDim,
|
||||
const cuuint64_t* globalStrides,
|
||||
const cuuint32_t* boxDim,
|
||||
const cuuint32_t* elementStrides,
|
||||
CUtensorMapInterleave interleave,
|
||||
CUtensorMapSwizzle swizzle,
|
||||
CUtensorMapL2promotion l2Promotion,
|
||||
CUtensorMapFloatOOBfill oobFill) {
|
||||
return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
|
||||
tensorMap,
|
||||
tensorDataType,
|
||||
tensorRank,
|
||||
globalAddress,
|
||||
globalDim,
|
||||
globalStrides,
|
||||
boxDim,
|
||||
elementStrides,
|
||||
interleave,
|
||||
swizzle,
|
||||
l2Promotion,
|
||||
oobFill);
|
||||
}
|
||||
|
||||
|
||||
#include <cutlass/version.h>
|
||||
#include <cutlass/core_io.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/gemm/device/gemm.h>
|
||||
@ -27,7 +57,16 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/trace.h>
|
||||
#include <cutlass/util/host_tensor.h>
|
||||
#include <cutlass/version.h>
|
||||
|
||||
// Rename the global function symbol
|
||||
#if CUTLASS_VERSION == 351
|
||||
#include <cute/tensor.hpp>
|
||||
#else
|
||||
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
|
||||
#include <cute/tensor.hpp>
|
||||
#undef cuTensorMapEncodeTiled
|
||||
#endif
|
||||
// Set everything back to normal
|
||||
|
||||
#include <cutlass/gemm/collective/collective_builder.hpp>
|
||||
#include <cutlass/gemm/device/gemm_universal_adapter.h>
|
||||
@ -38,8 +77,6 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
|
||||
#include <cutlass/gemm/kernel/gemm_universal.hpp>
|
||||
#include <cutlass/util/packed_stride.hpp>
|
||||
|
||||
C10_DIAGNOSTIC_POP()
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
||||
namespace {
|
||||
|
||||
@ -148,7 +185,11 @@ void f8f8bf16_rowwise_impl(
|
||||
|
||||
// Implement rowwise scaling epilogue.
|
||||
constexpr int ColBroadcastStages = 0;
|
||||
#if CUTLASS_VERSION == 351
|
||||
constexpr int RowBroadcastStages = 0;
|
||||
#else
|
||||
constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;
|
||||
#endif
|
||||
|
||||
using XScale = cutlass::epilogue::fusion::
|
||||
Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;
|
||||
@ -164,10 +205,19 @@ void f8f8bf16_rowwise_impl(
|
||||
Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;
|
||||
|
||||
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
|
||||
|
||||
#if CUTLASS_VERSION == 351
|
||||
#define FLIPPED_SCALES ;
|
||||
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Multiply,
|
||||
WScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
|
||||
Multiply,
|
||||
WScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, XScale, Accum>>;
|
||||
#else
|
||||
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Multiply,
|
||||
XScale,
|
||||
cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>;
|
||||
#endif
|
||||
|
||||
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
|
||||
Cast,
|
||||
@ -230,6 +280,7 @@ void f8f8bf16_rowwise_impl(
|
||||
StrideOutput stride_output = cutlass::make_cute_packed_stride(
|
||||
StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
|
||||
|
||||
#ifdef FLIPPED_SCALES
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
@ -245,6 +296,23 @@ void f8f8bf16_rowwise_impl(
|
||||
stride_output,
|
||||
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
|
||||
stride_output}};
|
||||
#else
|
||||
typename Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K},
|
||||
{reinterpret_cast<DtypeA*>(XQ.data_ptr()),
|
||||
stride_a,
|
||||
reinterpret_cast<DtypeB*>(WQ.data_ptr()),
|
||||
stride_b},
|
||||
{{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
|
||||
: nullptr},
|
||||
{{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
|
||||
{{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
|
||||
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
|
||||
stride_output,
|
||||
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
|
||||
stride_output}};
|
||||
#endif
|
||||
|
||||
Gemm gemm;
|
||||
|
||||
|
||||
@ -889,9 +889,9 @@ void lstm_miopen(Tensor& output, Tensor& hy, Tensor& cy,
int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
auto result = _miopen_impl(input, std::make_tuple(hx[0], hx[1]), params, has_biases,
miopenLSTM, num_layers, dropout_p, train, bidirectional, batch_first);
output = std::move(result.first);
hy = std::move(std::get<0>(result.second));
cy = std::move(std::get<1>(result.second));
output = result.first;
hy = std::get<0>(result.second);
cy = std::get<1>(result.second);
}

void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
@ -900,9 +900,9 @@ void lstm_packed_miopen(Tensor& output, Tensor& hy, Tensor& cy,
int64_t num_layers, double dropout_p, bool train, bool bidirectional) {
auto result = _miopen_impl(data, batch_sizes, std::make_tuple(hx[0], hx[1]),
params, has_biases, miopenLSTM, num_layers, dropout_p, train, bidirectional);
output = std::move(result.first);
hy = std::move(std::get<0>(result.second));
cy = std::move(std::get<1>(result.second));
output = result.first;
hy = std::get<0>(result.second);
cy = std::get<1>(result.second);
}
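Both hunks above drop the std::move calls when unpacking _miopen_impl's result, so the outputs are copy-assigned rather than move-assigned. A minimal sketch of the difference with a cheap-to-move stand-in type (names are illustrative, not the Tensor API):

```cpp
#include <iostream>
#include <tuple>
#include <utility>
#include <vector>

// Stand-in for a heavy result; copying is expensive, moving is cheap.
using Buffer = std::vector<double>;

std::pair<Buffer, std::tuple<Buffer, Buffer>> make_result() {
  Buffer out(1 << 20, 1.0), h(1 << 20, 2.0), c(1 << 20, 3.0);
  return {std::move(out), std::make_tuple(std::move(h), std::move(c))};
}

int main() {
  Buffer output, hy, cy;

  auto result = make_result();
  // Copy assignment: every buffer is duplicated element by element.
  output = result.first;
  hy = std::get<0>(result.second);
  cy = std::get<1>(result.second);

  auto result2 = make_result();
  // Move assignment: the buffers' storage is stolen, no element-wise copy.
  output = std::move(result2.first);
  hy = std::move(std::get<0>(result2.second));
  cy = std::move(std::get<1>(result2.second));

  std::cout << "moved-from size: " << result2.first.size() << "\n";  // typically 0
  return 0;
}
```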
REGISTER_CUDA_DISPATCH(lstm_miopen_stub, &lstm_miopen)
|
||||
|
||||
@ -17,8 +17,7 @@
|
||||
#include <ATen/ops/mm_native.h>
|
||||
#endif
|
||||
|
||||
namespace at::native {
|
||||
namespace xpu {
|
||||
namespace at::native::xpu {
|
||||
|
||||
// result = beta * self + alpha * (mat1 * mat2)
|
||||
Tensor& addmm_out(
|
||||
@ -455,7 +454,7 @@ Tensor& tensordot_out(
|
||||
TORCH_LIBRARY_IMPL(aten, XPU, m) {
|
||||
m.impl("tensordot.out", TORCH_FN(tensordot_out));
|
||||
}
|
||||
} // namespace xpu
|
||||
} // namespace at::native::xpu
|
||||
|
||||
TORCH_IMPL_FUNC(addmm_out_xpu)
|
||||
(const Tensor& self,
|
||||
@ -470,13 +469,11 @@ TORCH_IMPL_FUNC(addmm_out_xpu)
|
||||
|
||||
TORCH_IMPL_FUNC(mm_out_xpu)
|
||||
(const Tensor& self, const Tensor& mat2, const Tensor& result) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
xpu::mm_out(self, mat2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(bmm_out_xpu)
|
||||
(const Tensor& self, const Tensor& batch2, const Tensor& result) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
xpu::bmm_out(self, batch2, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
@ -501,13 +498,7 @@ TORCH_IMPL_FUNC(baddbmm_out_xpu)
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
xpu::baddbmm_out(
|
||||
self,
|
||||
batch1,
|
||||
batch2,
|
||||
beta,
|
||||
alpha,
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
const_cast<Tensor&>(result));
|
||||
self, batch1, batch2, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(addmv_out_xpu)
|
||||
@ -517,8 +508,5 @@ TORCH_IMPL_FUNC(addmv_out_xpu)
|
||||
const Scalar& beta,
|
||||
const Scalar& alpha,
|
||||
const Tensor& result) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
|
||||
xpu::addmv_out(self, mat, vec, beta, alpha, const_cast<Tensor&>(result));
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
|
||||
@ -100,6 +100,8 @@ at::Tensor quantized_convolution(
|
||||
{c10::kXPU, c10::xpu::current_device()});
|
||||
auto stream = GpuStreamManager::Instance().get_stream();
|
||||
|
||||
// create usr_md for tensors, and md for conv primitive
|
||||
dnnl::memory::desc src_md, weight_md, output_md;
|
||||
// input tensors config
|
||||
dnnl::memory::dims src_dims = act.sizes().vec();
|
||||
dnnl::memory::dims weight_dims = weight.sizes().vec();
|
||||
@ -128,8 +130,7 @@ at::Tensor quantized_convolution(
|
||||
|
||||
bool src_need_zp = (act_scale != 0);
|
||||
|
||||
// create usr_md for tensors, and md for conv primitive
|
||||
auto [src_md, weight_md, output_md] =
|
||||
std::tie(src_md, weight_md, output_md) =
|
||||
qconv_get_md(act, weight, output, groups);
|
||||
|
||||
// get tensor md
|
||||
|
||||
@ -870,12 +870,7 @@ id<MTLLibrary> MetalShaderLibrary::compileLibrary(const std::string& src) {
|
||||
const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding];
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
library = [device newLibraryWithSource:str options:options error:&error];
|
||||
if (library == nil) {
|
||||
if ([error domain] == MTLLibraryErrorDomain && [error code] == MTLLibraryErrorCompileFailure) {
|
||||
throw c10::SyntaxError([[error localizedDescription] UTF8String]);
|
||||
}
|
||||
TORCH_CHECK(false, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
return library;
|
||||
}
|
||||
|
||||
|
||||
@ -420,7 +420,7 @@ TORCH_API int register_conv_params<3>();
|
||||
|
||||
int register_linear_params() {
|
||||
using SerializationType = std::tuple<at::Tensor, std::optional<at::Tensor>>;
|
||||
[[maybe_unused]] static auto register_linear_params =
|
||||
static auto register_linear_params =
|
||||
torch::selective_class_<LinearPackedParamsBase>(
|
||||
"quantized", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase"))
|
||||
.def_pickle(
|
||||
@ -495,7 +495,7 @@ int register_embedding_params() {
|
||||
std::vector<double>,
|
||||
std::vector<int64_t>>;
|
||||
|
||||
[[maybe_unused]] static auto register_embedding_params =
|
||||
static auto register_embedding_params =
|
||||
torch::selective_class_<EmbeddingPackedParamsBase>(
|
||||
"quantized", TORCH_SELECTIVE_CLASS("EmbeddingPackedParamsBase"))
|
||||
.def_pickle(
|
||||
|
||||
@ -11,6 +11,27 @@
// sparsification, as a bitmask.
// NOTE: Algorithms might select LESS than 8 values in total in some cases.

namespace cutlass::platform {
template <>
struct numeric_limits<cutlass::bfloat16_t> {
CUTLASS_HOST_DEVICE
static cutlass::bfloat16_t infinity() {
return cutlass::bfloat16_t::bitcast(0x7f80);
}
};

#if CUTLASS_VERSION == 341
template <>
struct numeric_limits<cutlass::half_t> {
CUTLASS_HOST_DEVICE
static cutlass::half_t infinity() {
return cutlass::half_t::bitcast(0x7c00);
}
};
#endif

} // namespace cutlass::platform
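The specialization above builds infinity() for bfloat16 by bit-casting 0x7f80: all-ones exponent and a zero mantissa in the 16-bit layout. A quick standalone check that this bit pattern widens to +inf in float32 (the guarded half_t case relies on 0x7c00, the IEEE half +inf pattern, in the same way):

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>

// bfloat16 shares the float32 exponent layout, so widening is a 16-bit shift.
float bf16_bits_to_float(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

int main() {
  // 0x7f80: sign 0, exponent 0xff (all ones), mantissa 0 -> +infinity.
  float inf_from_bf16 = bf16_bits_to_float(0x7f80);
  std::cout << std::boolalpha
            << "0x7f80 is +inf: " << std::isinf(inf_from_bf16) << "\n"
            << "largest finite bf16 (0x7f7f): " << bf16_bits_to_float(0x7f7f) << "\n";
  return 0;
}
```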
namespace at::native {
|
||||
|
||||
template <typename Element, typename Pointwise>
|
||||
|
||||
@ -916,6 +916,7 @@ _flash_attention_forward(
|
||||
std::optional<Tensor> seqused_k = _seqused_k;
|
||||
std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
|
||||
std::optional<Tensor> alibi_slopes = _alibi_slopes;
|
||||
const float softcap = 0.0;
|
||||
|
||||
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
|
||||
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
|
||||
@ -957,6 +958,7 @@ _flash_attention_forward(
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
softcap,
|
||||
return_debug_mask,
|
||||
std::nullopt /*gen_*/);
|
||||
} else {
|
||||
@ -980,6 +982,7 @@ _flash_attention_forward(
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
softcap,
|
||||
return_debug_mask, /*return_softmax (this is used for testing)*/
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
@ -94,6 +94,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
|
||||
// Currently unused args:
|
||||
std::optional<at::Tensor> alibi_slopes{std::nullopt};
|
||||
const float softcap = 0.0;
|
||||
|
||||
bool determinisitic{false};
|
||||
auto& ctx = at::globalContext();
|
||||
@ -132,6 +133,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
softcap,
|
||||
determinisitic,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
@ -154,6 +156,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
softcap,
|
||||
determinisitic,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
|
||||
@ -1,74 +0,0 @@
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_causal>
|
||||
struct Alibi {
|
||||
|
||||
const float alibi_slope;
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
|
||||
__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
|
||||
: alibi_slope(alibi_slope)
|
||||
, max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q) {
|
||||
};
|
||||
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}

};

} // namespace pytorch_flash
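The deleted Alibi functor adds slope * col_idx in the causal branch and subtracts slope * |row + (max_seqlen_k - max_seqlen_q) - col| otherwise, as shown above. A host-side scalar sketch of the non-causal bias formula with illustrative values; the real kernel applies it per MMA fragment on the GPU:

```cpp
#include <cmath>
#include <cstdio>

// Non-causal ALiBi bias for one (row, col) score, as in the deleted kernel:
// the key/query length difference lines up the diagonals of the two sequences.
float alibi_bias(float slope, int row_idx, int col_idx,
                 int max_seqlen_q, int max_seqlen_k) {
  return -slope * std::abs(float(row_idx + max_seqlen_k - max_seqlen_q - col_idx));
}

int main() {
  const float slope = 0.0625f;   // one attention head's slope
  const int max_q = 4, max_k = 8;
  for (int row = 0; row < max_q; ++row) {
    for (int col = 0; col < max_k; ++col) {
      std::printf("% .3f ", alibi_bias(slope, row, col, max_q, max_k));
    }
    std::printf("\n");
  }
  return 0;
}
```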
@ -1,46 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Varlen=true>
|
||||
struct BlockInfo {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
||||
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename index_t>
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
const int seqlen_k_cache;
const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace pytorch_flash
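The deleted BlockInfo turns cumulative sequence lengths into per-batch offsets: for packed variable-length batches the offset is cu_seqlens[b] * row_stride, otherwise it is b * batch_stride. A host-side sketch of that indexing, assuming an illustrative layout rather than the flash-attention parameter struct:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Offset of batch b in a packed variable-length layout (cu_seqlens given)
// or in a padded layout (no cu_seqlens, fixed batch stride).
int64_t q_offset(const std::vector<int64_t>& cu_seqlens, int b,
                 int64_t batch_stride, int64_t row_stride) {
  if (cu_seqlens.empty()) {
    return int64_t(b) * batch_stride;   // padded: fixed-size batches
  }
  return cu_seqlens[b] * row_stride;    // packed: skip the preceding rows
}

int main() {
  // Three sequences of lengths 3, 5, 2 packed back to back.
  std::vector<int64_t> cu_seqlens = {0, 3, 8, 10};
  const int64_t row_stride = 64;        // elements per row
  for (int b = 0; b < 3; ++b) {
    int64_t len = cu_seqlens[b + 1] - cu_seqlens[b];
    std::cout << "batch " << b
              << ": offset=" << q_offset(cu_seqlens, b, /*batch_stride=*/0, row_stride)
              << " seqlen=" << len << "\n";
  }
  return 0;
}
```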
@ -1,96 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct Dropout {
|
||||
|
||||
const unsigned long long seed, offset;
|
||||
const uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
|
||||
const uint8_t p_dropout_in_uint8_t,
|
||||
const int bid, const int hid, const int tid, const int nheads)
|
||||
: seed(seed)
|
||||
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
|
||||
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
|
||||
int block_row_start, int block_col_start, int block_row_stride) {
|
||||
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
|
||||
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||
};
|
||||
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
|
||||
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
|
||||
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
#pragma unroll
|
||||
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
|
||||
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
|
||||
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
|
||||
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
|
||||
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
|
||||
// Special implementation for 16-bit types: we duplicate the threshold to the
|
||||
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
|
||||
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
|
||||
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
|
||||
// the random value is less than the threshold.
|
||||
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
|
||||
// We're exploiting the fact that floating point comparison is equivalent to integer
|
||||
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
|
||||
if (!encode_dropout_in_sign_bit
|
||||
&& (std::is_same_v<T, cutlass::half_t> || std::is_same_v<T, cutlass::bfloat16_t>)) {
|
||||
uint16_t rnd_16[16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
|
||||
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t mask;
|
||||
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
|
||||
tensor_uint32(i) &= mask;
|
||||
}
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
|
||||
}
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
}
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
|
||||
// // }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
||||
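// --- Illustrative sketch (not part of the original file) ---
// The kernel above keeps an element iff its random byte is <= p_dropout_in_uint8_t,
// where the threshold is the keep-probability quantized to 8 bits; the f16x2 path is
// just a vectorized form of that comparison. A scalar CPU analogue of the decision
// (the 1 / p_keep rescale is folded into rp_dropout elsewhere in the real code):
#include <cstdint>

float apply_dropout_scalar(float val, uint8_t rnd_byte, float p_keep) {
    const uint8_t threshold = uint8_t(p_keep * 255.0f);  // same quantization idea
    const bool keep = rnd_byte <= threshold;
    return keep ? val / p_keep : 0.0f;  // rescale kept values so the expectation is unchanged
}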
@ -1,190 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
|
||||
namespace pytorch_flash {
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
using index_t = int64_t;
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
void * __restrict__ oaccum_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
// The pointer to the P matrix.
|
||||
void * __restrict__ p_ptr;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
void * __restrict__ softmax_lseaccum_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
// If provided, the actual length of each k sequence.
|
||||
int * __restrict__ seqused_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The K_new and V_new matrices.
|
||||
void * __restrict__ knew_ptr;
|
||||
void * __restrict__ vnew_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t knew_batch_stride;
|
||||
index_t vnew_batch_stride;
|
||||
index_t knew_row_stride;
|
||||
index_t vnew_row_stride;
|
||||
index_t knew_head_stride;
|
||||
index_t vnew_head_stride;
|
||||
|
||||
// The cos and sin matrices for rotary embedding.
|
||||
void * __restrict__ rotary_cos_ptr;
|
||||
void * __restrict__ rotary_sin_ptr;
|
||||
|
||||
// The indices to index into the KV cache.
|
||||
int * __restrict__ cache_batch_idx;
|
||||
|
||||
// Paged KV cache
|
||||
int * __restrict__ block_table;
|
||||
index_t block_table_batch_stride;
|
||||
int page_block_size;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
// uint32_t p_dropout_in_uint;
|
||||
// uint16_t p_dropout_in_uint16_t;
|
||||
uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_softmax_rp_dropout;
|
||||
|
||||
// Local window size
|
||||
int window_size_left, window_size_right;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
int64_t * extragraph_offset;
|
||||
int64_t * seed;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
bool is_seqlens_k_cumulative;
|
||||
|
||||
bool is_rotary_interleaved;
|
||||
|
||||
int num_splits; // For split-KV version
|
||||
|
||||
void * __restrict__ alibi_slopes_ptr;
|
||||
index_t alibi_slopes_batch_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_bwd_params : public Flash_fwd_params {
|
||||
|
||||
// The dO and dQKV matrices.
|
||||
void *__restrict__ do_ptr;
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// To accumulate dQ
|
||||
void *__restrict__ dq_accum_ptr;
|
||||
void *__restrict__ dk_accum_ptr;
|
||||
void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
|
||||
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
|
||||
// dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dO, dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
index_t do_batch_stride;
|
||||
index_t do_row_stride;
|
||||
index_t do_head_stride;
|
||||
index_t dq_batch_stride;
|
||||
index_t dk_batch_stride;
|
||||
index_t dv_batch_stride;
|
||||
index_t dq_row_stride;
|
||||
index_t dk_row_stride;
|
||||
index_t dv_row_stride;
|
||||
index_t dq_head_stride;
|
||||
index_t dk_head_stride;
|
||||
index_t dv_head_stride;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void *__restrict__ dsoftmax_sum;
|
||||
|
||||
bool deterministic;
|
||||
index_t dq_accum_split_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
|
||||
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
|
||||
|
||||
} // namespace pytorch_flash
|
||||
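// --- Illustrative sketch (not part of the original file) ---
// Flash_fwd_params mostly mirrors the strides of the incoming tensors. For a query of
// shape [batch, seqlen_q, num_heads, head_dim] with a contiguous last dimension, the
// fields the struct wants would be filled roughly like this (a sketch, assuming a caller
// that has already validated dtype and device):
#include <ATen/core/Tensor.h>

void fill_q_strides_sketch(Flash_fwd_params &params, const at::Tensor &q) {
    params.q_ptr = q.data_ptr();
    params.q_batch_stride = q.stride(0);  // elements between consecutive batches
    params.q_row_stride   = q.stride(1);  // elements between consecutive rows (tokens)
    params.q_head_stride  = q.stride(2);  // elements between consecutive heads
    params.h = int(q.size(2));
    params.d = int(q.size(3));
}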
@ -2,6 +2,7 @@
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
|
||||
#include <cstdint>
|
||||
@ -32,13 +33,18 @@
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
// #include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
|
||||
#include <flash.h>
|
||||
#include <namespace_config.h>
|
||||
#include <static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
// #include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
namespace FLASH_NAMESPACE {
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
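// --- Illustrative usage (not part of the original diff) ---
// CHECK_SHAPE expands to a TORCH_CHECK that compares the full size vector, e.g.
//   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
// fails with "q must have shape (batch_size, seqlen_q, num_heads, head_size_og)"
// whenever q.sizes() differs from that at::IntArrayRef.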
@ -70,7 +76,9 @@ void set_params_fprop(Flash_fwd_params &params,
|
||||
float softmax_scale,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
bool seqlenq_ngroups_swapped=false) {
|
||||
const float softcap,
|
||||
bool seqlenq_ngroups_swapped=false,
|
||||
const bool unpadded_lse=false) {
|
||||
|
||||
// Reset the parameters
|
||||
params = {};
|
||||
@ -126,8 +134,19 @@ void set_params_fprop(Flash_fwd_params &params,
|
||||
params.d_rounded = d_rounded;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
||||
TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
|
||||
#endif
|
||||
if (softcap > 0.0) {
|
||||
params.softcap = softmax_scale / softcap;
|
||||
params.scale_softmax = softcap;
|
||||
params.scale_softmax_log2 = softcap * M_LOG2E;
|
||||
} else{
|
||||
// Remove potential NaN
|
||||
params.softcap = 0.0;
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
}
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
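// --- Note on the softcap parameter split above (a sketch, not part of the original diff) ---
// With softcapping enabled, the effective logit is softcap * tanh(qk * softmax_scale / softcap).
// The code implements that in two stages: the kernel first multiplies qk by params.softcap
// (= softmax_scale / softcap) and applies tanh, and the usual softmax scaling then uses
// scale_softmax = softcap. A scalar reference of the intended math:
#include <cmath>

float softcapped_logit(float qk, float softmax_scale, float softcap) {
    return softcap > 0.f ? softcap * std::tanh(qk * softmax_scale / softcap)
                         : qk * softmax_scale;
}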
@ -162,6 +181,8 @@ void set_params_fprop(Flash_fwd_params &params,
|
||||
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
|
||||
#endif
|
||||
params.unpadded_lse = unpadded_lse;
|
||||
params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
|
||||
}
|
||||
|
||||
void set_params_dgrad(Flash_bwd_params &params,
|
||||
@ -195,7 +216,9 @@ void set_params_dgrad(Flash_bwd_params &params,
|
||||
float softmax_scale,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
bool deterministic) {
|
||||
const float softcap,
|
||||
bool deterministic,
|
||||
const bool unpadded_lse) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
||||
@ -208,7 +231,10 @@ void set_params_dgrad(Flash_bwd_params &params,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap,
|
||||
false, // seqlenq_ngroups_swapped
|
||||
unpadded_lse);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.do_ptr = dout.data_ptr();
|
||||
@ -244,11 +270,13 @@ void set_params_dgrad(Flash_bwd_params &params,
|
||||
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
HEADDIM_SWITCH(params.d, [&] {
|
||||
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
|
||||
}
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
|
||||
run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -357,6 +385,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
|
||||
@ -396,6 +425,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
@ -441,7 +472,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
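// --- Worked example for the new rounding rule above (not part of the original diff) ---
// head_size is padded to a multiple of 8, and head_size_rounded to a multiple of 32,
// except that any rounded value reaching 224 is routed straight to the 256 bucket:
//   head_size_og =  64 -> head_size =  64 -> head_size_rounded =  64
//   head_size_og = 160 -> head_size = 160 -> head_size_rounded = 160
//   head_size_og = 200 -> head_size = 200 -> round_multiple(200, 32) = 224 -> 256
//   head_size_og = 256 -> head_size = 256 -> round_multiple(256, 32) = 256 -> 256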
@ -476,11 +507,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap
|
||||
);
|
||||
|
||||
|
||||
// Keep references to these tensors to extend their lifetime
|
||||
auto [softmax_lse_accum, out_accum] = set_params_splitkv(params, batch_size, num_heads,
|
||||
at::Tensor softmax_lse_accum, out_accum;
|
||||
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
|
||||
head_size, seqlen_k, seqlen_q,
|
||||
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
|
||||
|
||||
@ -497,26 +531,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
seed_t = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(seed_t.data_ptr());
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
}else{
|
||||
seed_t = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
}
|
||||
|
||||
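// --- Note on the RNG-state change above (a sketch, not part of the original diff) ---
// Outside CUDA graph capture the Philox state can be unpacked on the host right away;
// under capture only device pointers are valid. That is why the new code keeps a
// 2-element uint64 CUDA tensor (params.rng_state) that the kernel fills with
// {seed, offset} itself, instead of the old host-side seed_t / offset_t scalars:
//
//   if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
//       auto [seed, offset] = at::cuda::philox::unpack(philox_state);  // host values, usable now
//   } else {
//       // graph capture: hand the kernel a device buffer and let it record the state there
//   }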
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
@ -556,6 +576,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
|
||||
@ -604,6 +625,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
const int head_size_og = sizes[2];
|
||||
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
|
||||
const int num_blocks = !paged_KV ? 0 : k.size(0);
|
||||
const int page_block_size = !paged_KV ? 1 : k.size(1);
|
||||
@ -667,7 +690,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
||||
@ -679,7 +701,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
@ -689,7 +711,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
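// --- Illustrative sketch (not part of the original diff) ---
// The varlen path now stores the logsumexp unpadded as [num_heads, total_q] instead of
// [batch, num_heads, max_seqlen_q]. Element (b, h, i) of the old layout lives at row h,
// column cu_seqlens_q[b] + i of the new one. A hypothetical index helper:
#include <cstdint>

int64_t lse_index_unpadded(const int* cu_seqlens_q, int64_t total_q,
                           int64_t b, int64_t h, int64_t i) {
    // flat index into a row-major [num_heads, total_q] buffer
    return h * total_q + cu_seqlens_q[b] + i;
}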
@ -720,7 +742,10 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqlenq_ngroups_swapped);
|
||||
softcap,
|
||||
seqlenq_ngroups_swapped,
|
||||
/*unpadded_lse*/true);
|
||||
params.total_q = total_q;
|
||||
if (paged_KV) {
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
@ -740,9 +765,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
at::Tensor seed_t, offset_t;
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
@ -750,26 +775,12 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
seed_t = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(seed_t.data_ptr());
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
seed_t = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
}
|
||||
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
@ -788,7 +799,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
|
||||
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
|
||||
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
|
||||
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
|
||||
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
|
||||
@ -797,7 +808,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
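// --- Illustrative sketch (not part of the original diff) ---
// The extra BOOL_SWITCH above turns is_causal into a compile-time template parameter, so
// each (dtype, head_dim, causal) combination gets its own kernel instantiation. A simplified
// stand-in for that kind of switch macro (the real one lives in static_switch.h):
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)      \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()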
@ -818,6 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
@ -877,7 +891,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
@ -976,21 +990,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
deterministic);
|
||||
softcap,
|
||||
deterministic,
|
||||
/*unpadded_lse*/false);
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
params.rng_state = philox_seed.data_ptr<uint64_t>();
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
@ -1019,7 +1029,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
|
||||
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
@ -1034,6 +1044,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset)
|
||||
@ -1099,7 +1110,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
@ -1154,7 +1165,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
if (loop) {
|
||||
// We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
|
||||
@ -1165,6 +1176,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
// cu_seqlens[i + 1] + 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
|
||||
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
|
||||
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
|
||||
// Same holds for softmax_d, since LSE is stored in unpadded format.
|
||||
if (!deterministic) {
|
||||
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
} else {
|
||||
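// --- Illustrative sketch (not part of the original diff) ---
// With the (total_q + 128 * batch_size) allocation described above, sequence i is written
// starting at row cu_seqlens_q[i] + 128 * i, which guarantees a gap of at least 128 rows
// (the max seqlen_q block size) between consecutive sequences, so out-of-bounds atomicAdds
// from the last block of one sequence never touch the next one. A hypothetical index helper:
#include <cstdint>

int64_t dq_accum_row(const int* cu_seqlens_q, int64_t i /*sequence*/, int64_t m /*row within sequence*/) {
    return int64_t(cu_seqlens_q[i]) + 128 * i + m;
}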
@ -1210,21 +1222,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
deterministic);
|
||||
softcap,
|
||||
deterministic,
|
||||
/*unpadded_lse*/true);
|
||||
params.total_q = total_q;
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
params.rng_state = philox_seed.data_ptr<uint64_t>();
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
@ -1265,6 +1273,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
||||
int num_splits
|
||||
) {
|
||||
@ -1375,7 +1384,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
@ -1403,7 +1412,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
/*p_dropout=*/0.f,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap
|
||||
);
|
||||
|
||||
at::Tensor k, v, k_padded, v_padded;
|
||||
if (k_.has_value()) {
|
||||
@ -1486,7 +1497,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
}
|
||||
|
||||
// Keep references to these tensors to extend their lifetime
|
||||
auto [softmax_lse_accum, out_accum] = set_params_splitkv(params, batch_size, num_heads,
|
||||
at::Tensor softmax_lse_accum, out_accum;
|
||||
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads,
|
||||
head_size, seqlen_k, seqlen_q,
|
||||
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
|
||||
|
||||
|
||||
@ -1,10 +1,11 @@
#pragma once
#include <cstddef>

#include <namespace_config.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

namespace pytorch_flash {
namespace FLASH_NAMESPACE {

TORCH_API
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -18,6 +19,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        std::optional<at::Generator> gen_);

@ -39,6 +41,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
        bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool return_softmax,
        std::optional<at::Generator> gen_);

@ -59,6 +62,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset);
@ -84,8 +88,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic,
        const at::Tensor philox_seed,
        const at::Tensor philox_offset);

} // namespace pytorch_flash
} // namespace FLASH_NAMESPACE
@ -1,827 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/PhiloxUtils.cuh>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/dropout.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/alibi.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_N,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
// This gives the correct layout, idk why.
|
||||
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
|
||||
// Stride<Stride<_1, _64>, _8> >{},
|
||||
// auto t = make_tile(Layout<Shape<_8, _2, _2>,
|
||||
// Stride<_1, _64, _8> >{},
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
|
||||
make_layout(Int<TileShape_K>{}));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_N,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
auto t = make_tile(make_layout(Int<TileShape_M>{}),
|
||||
Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
|
||||
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
|
||||
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
|
||||
if (Is_local) {
|
||||
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
|
||||
}
|
||||
|
||||
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
|
||||
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
|
||||
+ (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
|
||||
+ (m_block_max - 1) * kBlockM;
|
||||
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
|
||||
+ (m_block_max - 1) * kBlockM;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.dq_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQdO{});
|
||||
Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
|
||||
Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
|
||||
// Double buffer for sQ
|
||||
Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
|
||||
Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
|
||||
Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
|
||||
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
|
||||
Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
|
||||
Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
|
||||
Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
|
||||
typename Kernel_traits::SmemLayoutPdS{});
|
||||
Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
|
||||
Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
|
||||
Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
|
||||
Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
|
||||
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
|
||||
// sP and sdQ share the same memory so be careful
|
||||
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
using GmemTiledCopydO = std::conditional_t<
|
||||
Is_first,
|
||||
typename Kernel_traits::GmemTiledCopydO,
|
||||
typename Kernel_traits::GmemTiledCopyQKV
|
||||
>;
|
||||
GmemTiledCopydO gmem_tiled_copy_dO;
|
||||
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
|
||||
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
!Seq_parallel,
|
||||
typename Kernel_traits::GmemTiledCopydQaccum,
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
|
||||
>;
|
||||
GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
|
||||
Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
|
||||
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
||||
// if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
|
||||
// __syncthreads();
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
|
||||
// printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
|
||||
// }
|
||||
|
||||
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
|
||||
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
|
||||
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
|
||||
Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
|
||||
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
|
||||
auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
|
||||
Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
|
||||
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
|
||||
Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
|
||||
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
|
||||
|
||||
// auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
|
||||
// if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
|
||||
// if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
|
||||
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
|
||||
|
||||
// Partition sP and sdS to match the accumulator partitioning
|
||||
// This has to be tiled_mma_sdp, not tiled_mma_dkv
|
||||
// auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
|
||||
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
// if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
|
||||
// if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
|
||||
// if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) {
|
||||
// printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data());
|
||||
// }
|
||||
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
|
||||
Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
|
||||
Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
|
||||
|
||||
auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
|
||||
Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
|
||||
Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
|
||||
|
||||
auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
|
||||
Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
|
||||
|
||||
auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
|
||||
Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
|
||||
|
||||
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
|
||||
Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
// Set predicates for k bounds
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||
}
|
||||
|
||||
// Prologue
|
||||
|
||||
// We'll advance gdQ and gdQaccum before the 1st read/write.
|
||||
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
|
||||
|
||||
int m_block = m_block_max - 1;
|
||||
int m_block_min = (!Is_causal && !Is_local)
|
||||
? 0
|
||||
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
|
||||
// If not local, we're guaranteed that m_block_min <= m_block:
|
||||
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
|
||||
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
|
||||
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
|
||||
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
|
||||
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
|
||||
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
|
||||
// However, if local, then it is possible to have some blocks of K & V not attending to any query.
|
||||
// We might need to exit early and write 0 to dK and dV for those blocks.
|
||||
// Otherwise we get wrong result for the case where we don't enter the for loop.
|
||||
// And we might read OOB elements from gQ and gdO.
|
||||
// This also covers the case where actual_seqlen_q == 0
|
||||
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
clear(tdKrdK);
|
||||
clear(tdVrdV);
|
||||
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
|
||||
tQsQ.data() = tQsQ.data() + size(sQ);
|
||||
tSsQ.data() = tSsQ.data() + size(sQ);
|
||||
tdKsQt.data() = tdKsQt.data() + size(sQ);
|
||||
}
|
||||
|
||||
if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
|
||||
|
||||
if (Kernel_traits::Is_V_in_regs) {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
Tensor tdOrdO = make_fragment_like(tdOgdO);
|
||||
Tensor tdOrO = make_fragment_like(tdOgO);
|
||||
if (!Is_first) {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
} else {
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
|
||||
Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N)
|
||||
static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||
// Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.
|
||||
Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccScS_row(mi));
|
||||
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
|
||||
}
|
||||
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead, if we set LSE = inf for OOB rows, probs are always 0.
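// A sketch of why this works (assuming the usual exp2(score * scale - lse) form implied by the
// scale_apply_exp2 call below): with lse(mi) = INFINITY every entry becomes exp2(finite - inf) = 0,
// whereas lse(mi) = 0 would leave probs of 1 that ALiBi-adjusted scores could turn into NaN.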
|
||||
// Tensor tKrK = make_fragment_like(tKsK);
|
||||
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
|
||||
// cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
|
||||
// // if (cute::thread(1, 0)) { print(tKrK); }
|
||||
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
if (!Kernel_traits::Is_V_in_regs) {
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
pytorch_flash::cp_async_fence();
|
||||
|
||||
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
|
||||
if (Is_first) {
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
if (Kernel_traits::Is_V_in_regs) {
|
||||
cute::cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
|
||||
}
|
||||
|
||||
const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
|
||||
pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
|
||||
bidb, bidh, tidx, params.h);
|
||||
|
||||
clear(acc_dv);
|
||||
clear(acc_dk);
|
||||
|
||||
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
pytorch_flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
|
||||
|
||||
for (; m_block >= m_block_min; --m_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
|
||||
clear(acc_s);
|
||||
cute::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
Tensor dP_sum = make_fragment_like(lse);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
|
||||
|
||||
// if (cute::thread0()) { print(sK); }
|
||||
// Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
|
||||
// #pragma unroll
|
||||
// for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
|
||||
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
|
||||
// }
|
||||
// if (cute::thread0()) { print(tSrK); }
|
||||
pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
|
||||
if (Has_alibi) {
|
||||
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
|
||||
}
|
||||
|
||||
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
|
||||
// actual_seqlen_k, because acc_s would be some finite value for those indices.
|
||||
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
|
||||
// so the result would still be correct.
|
||||
// However, it's possible that the values in acc_s are so large that they overflow
|
||||
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
|
||||
// So we need to mask out the elements beyond actual_seqlen_k.
|
||||
if (!Is_causal && !Is_local) {
|
||||
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
|
||||
pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k,
|
||||
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
|
||||
}
|
||||
} else if (Is_causal) {
|
||||
// Putting this causal masking right after acc_s is *much* slower for some reason.
|
||||
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), then for the 2nd block of seqlen_q (rows 128 to 255) we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|
||||
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
|
||||
pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q,
|
||||
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
|
||||
AtomLayoutMS * 16);
|
||||
}
|
||||
} else if (Is_local) {
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|
||||
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
|
||||
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
|
||||
pytorch_flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q, AtomLayoutMS * 16,
|
||||
params.window_size_left, params.window_size_right);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
// Compute the exponential value.
|
||||
pytorch_flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
|
||||
if constexpr (Is_dropout) {
|
||||
int warp_id = tidx / 32;
|
||||
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
|
||||
static_assert(MMA_N_SdP % 2 == 0);
|
||||
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
acc_s, block_row_idx, block_col_idx, AtomLayoutMS
|
||||
);
|
||||
}
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = !Is_dropout
|
||||
? pytorch_flash::convert_type<Element>(acc_s)
|
||||
: pytorch_flash::convert_type_relu<Element>(acc_s);
|
||||
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
|
||||
// if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
|
||||
Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaSdP>(rP.layout()));
|
||||
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
|
||||
// if (cute::thread0()) { print(tPaP); }
|
||||
// __syncthreads();
|
||||
// if (cute::thread0()) { print(sP); }
|
||||
|
||||
Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
|
||||
CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA
|
||||
|
||||
clear(acc_dp);
|
||||
// Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout()));
|
||||
// #pragma unroll
|
||||
// for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {
|
||||
// #pragma unroll
|
||||
// for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) {
|
||||
// acc_dp_reshaped(mi, ni) = -dP_sum(mi);
|
||||
// }
|
||||
// }
|
||||
|
||||
// if (cute::thread0()) { print(dP_sum); }
|
||||
|
||||
pytorch_flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
|
||||
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
|
||||
);
|
||||
|
||||
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
|
||||
auto pointwise_mult = [](float p, float dp, float d) {
|
||||
return p * (!Is_dropout || p >= 0 ? dp - d : d);
|
||||
};
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(dS); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(dS); ++ni) {
|
||||
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) { print(dS); }
|
||||
|
||||
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
|
||||
if (Is_first || Seq_parallel) {
|
||||
clear(acc_dq);
|
||||
} else {
|
||||
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
|
||||
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
|
||||
make_layout(get<0>(acc_dq.layout()),
|
||||
get<2>(acc_dq.layout()),
|
||||
get<1>(acc_dq.layout())));
|
||||
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
|
||||
}
|
||||
|
||||
if (Double_buffer && m_block > m_block_min) {
|
||||
// Double buffer for sQ
|
||||
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
|
||||
tQsQ.data() = tQsQ.data() + sQ_offset;
|
||||
tSsQ.data() = tSsQ.data() + sQ_offset;
|
||||
// Advance gQ
|
||||
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
|
||||
// Convert dS from fp32 to fp16
|
||||
Tensor tdSrdS = pytorch_flash::convert_type<Element>(dS_reshaped);
|
||||
// if (cute::thread0()) { print(tPrP); }
|
||||
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
|
||||
__syncthreads();
|
||||
|
||||
// Layout p_l = tPrP.layout();
|
||||
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
|
||||
// pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
|
||||
// pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
|
||||
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
|
||||
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
|
||||
// if (cute::thread0()) { print(acc_dv); }
|
||||
|
||||
__syncthreads(); // Need syncthreads since we're writing to the same sdO location
|
||||
|
||||
if (m_block > m_block_min) {
|
||||
// Advance gdO
|
||||
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
|
||||
if (Is_first) {
|
||||
tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
|
||||
} else {
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
}
|
||||
|
||||
pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
|
||||
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
|
||||
// if (cute::thread0()) { print(acc_dq); }
|
||||
|
||||
if (m_block > m_block_min) {
|
||||
gLSE.data() = gLSE.data() + (-int(kBlockM));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
|
||||
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
|
||||
}
|
||||
|
||||
if (!Is_last) {
|
||||
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
|
||||
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
|
||||
make_layout(get<0>(acc_dq.layout()),
|
||||
get<2>(acc_dq.layout()),
|
||||
get<1>(acc_dq.layout())));
|
||||
if (!Seq_parallel) {
|
||||
cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
|
||||
} else {
|
||||
// if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
|
||||
// Convert acc_dq from fp32 to fp16
|
||||
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
|
||||
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
||||
}
|
||||
|
||||
pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
|
||||
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
|
||||
// if (cute::thread0()) { print(acc_dk); }
|
||||
if (Double_buffer) { // Double buffer for sQ
|
||||
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
|
||||
}
|
||||
if (!Double_buffer && m_block > m_block_min) {
|
||||
__syncthreads();
|
||||
// Advance gQ
|
||||
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
if (Is_first && m_block > m_block_min) {
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
if (Is_last) {
|
||||
__syncthreads();
|
||||
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
|
||||
tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
|
||||
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tdQgdQ); ++m) {
|
||||
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
if (Is_dropout) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
|
||||
|
||||
// Convert acc_dv from fp32 to fp16
|
||||
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
|
||||
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
|
||||
|
||||
Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
|
||||
// Partition sdV and sdK to match the accumulator partitioning
|
||||
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// We need syncthreads here since we're writing to the same location as sK and sV.
|
||||
// Without syncthreads, some thread might modify the location of sK while another thread
|
||||
// is reading it for dQ gemm, leading to a race condition.
|
||||
// If Is_last, there's already a __syncthreads() at the end of the loop.
|
||||
if (!Is_last) { __syncthreads(); }
|
||||
|
||||
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
|
||||
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
|
||||
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
|
||||
__syncthreads();
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
|
||||
Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
if (n_block_max == 1) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
|
||||
} else {
|
||||
// Iterating backward from n_block_max - 1 to 0 might save 1 register
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
|
||||
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
|
||||
}
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
|
||||
}
|
||||
}
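// A reading of the scheduling above (the trailing bools are presumably Is_first / Is_last, matching
// the Is_first and Is_last branches inside compute_dq_dk_dv_1colblock): the single-block case runs
// with <true, true>; otherwise the last K/V block runs first with <true, false>, interior blocks run
// with <false, false>, and block 0 finishes with <false, true>.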
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
|
||||
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
|
||||
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace pytorch_flash
|
||||
@ -1,338 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <c10/cuda/CUDAException.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#endif
|
||||
|
||||
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
|
||||
defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
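// For reference, a sketch of how the first use below expands (assuming KERNEL_PARAM_MODIFIER is
// __grid_constant__ on a new-enough toolchain; otherwise it expands to nothing):
//
//   template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K>
//   __global__ void flash_bwd_dq_dk_dv_loop_kernel(__grid_constant__ const Flash_bwd_params params)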
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
|
||||
pytorch_flash::convert_dQ<Kernel_traits>(params, nsplits);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::convert_dKV<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
int gridDimx = num_n_block;
|
||||
if (params.deterministic) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
|
||||
}
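// Worked example with assumed numbers (108 SMs, params.b * params.h = 16): gridDimx becomes
// (108 + 16 - 1) / 16 = 7, so each (batch, head) pair uses at most 7 blocks along x regardless of
// seqlen_k, which matches the nsplits = gridDimx value passed to the dQ conversion kernel below in
// the deterministic path.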
|
||||
dim3 grid_n(gridDimx, params.b, params.h);
|
||||
|
||||
if (!params.deterministic) {
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else {
|
||||
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
|
||||
// a multiple of kBlockN, we'll need to apply mask in the loop.
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
#ifndef FLASHATTENTION_DISABLE_BACKWARD
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
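// Sanity check of the 104 KB figure above (Headdim = 32, constants taken from the condition itself):
// 2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128) = 2 * (640 * 32 + 32768) = 2 * 53248 = 106496 bytes = 104 KB.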
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// } else {
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
|
||||
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
|
||||
if constexpr (!Is_dropout) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
}; // namespace pytorch_flash
|
||||
@ -1,377 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
|
||||
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
|
||||
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
|
||||
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
|
||||
// The last coordinate is the "page".
|
||||
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
|
||||
make_layout(get<0>(do_.layout()),
|
||||
get<2>(do_.layout()))));
|
||||
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
|
||||
Tensor do_fp32 = pytorch_flash::convert_type<float>(do_reshaped);
|
||||
Tensor o_fp32 = pytorch_flash::convert_type<float>(o_reshaped);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
|
||||
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
|
||||
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
|
||||
}
|
||||
pytorch_flash::SumOp<float> sum_op;
|
||||
dP_sum_cur = pytorch_flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
|
||||
if (threadIdx.x % THREADS_PER_ROW == 0) {
|
||||
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
|
||||
}
|
||||
}
|
||||
}
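// In other words (a summary of the code above, not new behavior): for each row this thread owns,
// dot_do_o reduces sum_k do(row, k) * o(row, k) across THREADS_PER_ROW threads, multiplies by `scale`,
// and the lane with threadIdx.x % THREADS_PER_ROW == 0 writes the result into dP_sum.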
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
|
||||
inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
|
||||
|
||||
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
|
||||
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
|
||||
// TODO: careful, we're zeroing out dQaccum with type float4, but when
|
||||
// we do atomicAdds, we use type float. The layouts are different. Check this.
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
|
||||
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
||||
|
||||
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
|
||||
// Set predicates for k bounds
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
|
||||
|
||||
Tensor tdOrdO = make_fragment_like(tdOgdO);
|
||||
Tensor tdOrO = make_fragment_like(tdOgO);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
|
||||
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
|
||||
// so that (dP - dP_sum) is on the same scale.
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
if (Clear_dQaccum) {
|
||||
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
|
||||
// do atomicAdds on.
|
||||
Tensor zero = make_fragment_like(tdQgdQaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void clear_dKVaccum(const Params ¶ms) {
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int n_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
|
||||
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
|
||||
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
|
||||
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
|
||||
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
|
||||
Tensor zero = make_fragment_like(tdKgdKaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert dQ from dQaccum (in float) to fp16/bf16.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||

    Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
                             Shape<Int<kBlockM>, Int<kHeadDim>>{},
                             make_stride(params.dq_row_stride, _1{}));
    Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
                                  Shape<Int<kBlockM>, Int<kHeadDim>>{},
                                  make_stride(params.h * params.d_rounded, _1{}));

    Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdQ{});

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
    auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
    auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadQ tiled_mma_dq;
    auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
    auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
    Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)

    Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
    Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);

    Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));

    Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
    clear(acc_dq);
    for (int s = 0; s < nsplits; ++s) {
        cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
        #pragma unroll
        for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
        tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
    }
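    // acc_dq now holds the fp32 sum over the nsplits partial dQ buffers; the rescale below folds in
    // both the softmax scale and the reciprocal of the dropout keep-probability
    // (params.scale_softmax_rp_dropout).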
    #pragma unroll
    for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
    // Convert acc_dq from fp32 to fp16
    Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
    Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
    cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
    __syncthreads();
    Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
    cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);

    Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
    Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
    #pragma unroll
    for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
    );
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
// This is used in the case where we want to parallelize the backward across seqlen_q.
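// Note: in the seqlen_q-parallel backward, partial dK/dV contributions from the different m-blocks
// are accumulated directly into the fp32 dKaccum/dVaccum buffers, so unlike convert_dQ there is no
// per-split loop below: each tile is read once, rescaled, and cast to the output dtype.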
template<typename Kernel_traits, typename Params>
inline __device__ void convert_dKV(const Params &params) {
    using Element = typename Kernel_traits::Element;
    using ElementAccum = typename Kernel_traits::ElementAccum;
    using index_t = typename Kernel_traits::index_t;

    // Shared memory.
    extern __shared__ char smem_[];

    const int n_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.z;
    // The thread index.
    const int tidx = threadIdx.x;

    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;

    const BlockInfo binfo(params, bidb);
    if (n_block * kBlockN >= binfo.actual_seqlen_k) return;

    const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
        + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
    const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
        + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
    const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
        + n_block * kBlockN) * params.d_rounded;

    Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dk_row_stride, _1{}));
    Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
                             Shape<Int<kBlockN>, Int<kHeadDim>>{},
                             make_stride(params.dv_row_stride, _1{}));
    Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});
    Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
                                  Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                  Stride<Int<kHeadDim>, _1>{});

    Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                             typename Kernel_traits::SmemLayoutdKV{});
    Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)

    typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
    auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
    typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
    auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

    typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
    auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
    auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
    Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
    Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)

    Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
    Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
    Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
    Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
    Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

    Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
    Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
    CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
    CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

    Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
    Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
    cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
    #pragma unroll
    for (int i = 0; i < size(acc_dk); ++i) {
        acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
    }
    #pragma unroll
    for (int i = 0; i < size(acc_dv); ++i) {
        acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
    }
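    // Note the asymmetric rescaling above: dK is multiplied by scale_softmax_rp_dropout (the softmax
    // scale times the reciprocal dropout keep-probability) because the softmax scaling enters through
    // the logits, while dV = P^T dO only needs the dropout correction rp_dropout.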
    // Convert acc_dk from fp32 to fp16
    Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
    Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
    Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
    Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
    cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
    cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
    __syncthreads();
    Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
    Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
    cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
    cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);

    Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
    Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
    #pragma unroll
    for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
    // Clear_OOB_K must be false since we don't want to write zeros to gmem
    pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
    pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
        gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
    );
}

} // namespace flash
@ -1,378 +0,0 @@
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContextLight.h>

#include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h>

namespace pytorch_flash {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define ARCH_SUPPORTS_FLASH
#endif

#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
    defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
#define KERNEL_PARAM_MODIFIER __grid_constant__
#else
#define KERNEL_PARAM_MODIFIER
#endif

// Define a macro for unsupported architecture handling to centralize the error message
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");

// Use a macro to clean up kernel definitions
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)

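// For reference, DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, ...) expands to
//   template<typename Kernel_traits, bool Is_dropout, ...>
//   __global__ void flash_fwd_kernel(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
// i.e. the extra macro arguments become the kernel's non-type template parameters.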
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
    #if defined(ARCH_SUPPORTS_FLASH)
        static_assert(!(Is_causal && Is_local)); // Enforce constraints
        pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
    #if defined(ARCH_SUPPORTS_FLASH)
        pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
    #else
        FLASH_UNSUPPORTED_ARCH
    #endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
    static_assert(Log_max_splits >= 1);
    pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    // printf("smem_size = %d\n", smem_size);

    // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
    // https://github.com/kokkos/kokkos-kernels/issues/349
    // https://github.com/HazyResearch/flash-attention/issues/21

    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.b, params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    const bool return_softmax = params.p_ptr != nullptr;
    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
                    ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                        // Will only return softmax if dropout, to reduce compilation time.
                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                        // If Is_local, set Is_causal to false
                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
                        // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
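                        // Launches that need more than 48 KB of dynamic shared memory must opt in
                        // explicitly via cudaFuncSetAttribute(cudaFuncAttributeMaxDynamicSharedMemorySize);
                        // below 48 KB the default limit already suffices.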
                        if (smem_size >= 48 * 1024) {
                            C10_CUDA_CHECK(cudaFuncSetAttribute(
                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                        }
                        // int ctas_per_sm;
                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                        C10_CUDA_KERNEL_LAUNCH_CHECK();
                    });
                });
            });
        });
    });
}

template<typename Kernel_traits>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
    static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
    constexpr size_t smem_size = Kernel_traits::kSmemSize;
    const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
    const bool is_even_K = params.d == Kernel_traits::kHeadDim;
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
            EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
                LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
                    BOOL_SWITCH(params.num_splits > 1, Split, [&] {
                        BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
                            ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
                                // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
                                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                                // If Is_local, set Is_causal to false
                                auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
                                // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
                                if (smem_size >= 48 * 1024) {
                                    C10_CUDA_CHECK(cudaFuncSetAttribute(
                                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
                                }
                                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
                                C10_CUDA_KERNEL_LAUNCH_CHECK();
                            });
                        });
                    });
                });
            });
        });
    });
    if (params.num_splits > 1) {
        // We want kBlockM to be as small as possible for more parallelism.
        // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
        // If headdim is divisible by 64, then we set kBlockM = 8, etc.
        constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
        dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
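        // The chain below picks Log_max_splits = ceil(log2(num_splits)), clamped to [1, 7], so the
        // combine kernel is compiled for a power-of-two upper bound on the number of splits rather
        // than for every possible num_splits value.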
        EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
            if (params.num_splits <= 2) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 4) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 8) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 16) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 32) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 64) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            } else if (params.num_splits <= 128) {
                flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
            }
            C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
    }
}

template<typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int kBlockM = 64; // Fixed for all head dimensions
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
    // and for headdim 192 with block size 64 x 128.
    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
    constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
}

template<typename T>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 32;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
                // Using block size (64 x 256) is 27% slower for seqlen=2k
                // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 96;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            // These two are always slower
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 128;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
                // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
                if (is_sm8x) {
                    if constexpr(!Is_causal) {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    } else {
                        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                    }
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // 1st ones are good for H100, A100
                // 2nd one is good for A6000 bc we get slightly better occupancy
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
                // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
            }
        });
    });
}

template<typename T>
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 160;
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, H100, 128 x 32 is the fastest.
            // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
            // and 128 x 64 with 8 warps is the fastest for non-causal.
            if (is_sm8x) {
                if constexpr(!Is_causal) {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
                } else {
                    run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
                }
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 192;
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            if constexpr(!Is_dropout) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 224;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_block = %d\n", max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
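            // 112 KB = 2 bytes/element * 224 (head dim) * (128 + 2 * 64) rows, i.e. a 128 x 224 Q tile
            // plus 64 x 224 K and V tiles in fp16/bf16; only devices that allow at least that much
            // shared memory per block take the larger-tile branch.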
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
            // If we have N = 32, there are only 1024 elements to load at once, where each load
            // is 8 elements. This means we can only use 128 threads and not 256 threads.
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

template<typename T>
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 256;
    int device;
    cudaGetDevice(&device);
    int max_smem_per_sm, max_smem_per_block;
    cudaError status_ = cudaDeviceGetAttribute(
        &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
    status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        C10_CUDA_CHECK(status_);
    }
    // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        BOOL_SWITCH(params.is_causal, Is_causal, [&] {
            // For A100, we want to run with 128 x 64 (128KB smem).
            // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
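            // 2 * 256 * (128 + 2 * 64) = 128 KB is the footprint of the 128 x 64 tile configuration,
            // while 4 * 256 * (64 + 2 * 64) = 192 KB is twice the 96 KB footprint of 64 x 64: if an SM
            // can hold two 64 x 64 CTAs (H100), prefer that; otherwise (A100) use the single 128 x 64 CTA.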
            if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            }
            // 64 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
            // 96 KB
            // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
        });
    });
}

}; // namespace pytorch_flash
@ -1,344 +0,0 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/algorithm/copy.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/numeric_types.h>

namespace pytorch_flash{

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = int64_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};

// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
        Tile<Int<16 * kNWarps>, _16, _16>>;

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
    using SmemLayoutVtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

    static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
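    // Example: for kHeadDim = 128 with a 128 x 32 tile, kSmemQSize = 128 * 128 * 2 B = 32 KB and
    // kSmemKVSize = 2 * 32 * 128 * 2 B = 16 KB, giving the 48 KB figure quoted in the launch heuristics.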

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store

    using GmemLayoutAtomOaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopyOaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomOaccum{},
                        Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
    using GmemTiledCopyRotcossin = decltype(
        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
    using GmemTiledCopyRotcossinCont = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                        GmemLayoutAtomRotcossin{},
                        Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
};

// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
// No_double_buffer is another option to reduce smem usage, but will slow things down.
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
         int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
         bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_bwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Is_V_in_regs = Is_V_in_regs_;
    static constexpr bool No_double_buffer = No_double_buffer_;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
    static_assert(kNWarps % AtomLayoutMSdP == 0);
    static_assert(kNWarps % AtomLayoutNdKV == 0);
    static_assert(kNWarps % AtomLayoutMdQ == 0);

    using TiledMmaSdP = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
    using TiledMmadKV = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
    using TiledMmadQ = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
    using SmemLayoutAtomQdO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQdO = decltype(tile_to_shape(
        SmemLayoutAtomQdO{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));

    using SmemLayoutAtomKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutKV = decltype(tile_to_shape(
        // SmemLayoutAtomQdO{},
        SmemLayoutAtomKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

    using SmemLayoutKtransposed = decltype(
        composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
    using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));

    // TODO: generalize to other values of kBlockN
    // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
    // static constexpr int kPBlockN = kBlockN;
    // Temporarily disabling this for hdim 256 on sm86 and sm89
    // static_assert(kBlockN >= 64);
    static_assert(kBlockN >= 32);
    // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
    static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
    static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
    // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
    static constexpr int kSwizzlePdS = 3;
    using SmemLayoutAtomPdS = decltype(
        composition(Swizzle<kSwizzlePdS, 3, 3>{},
                    Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
                           Stride<Int<kPBlockN>, _1>>{}));
    using SmemLayoutPdS = decltype(tile_to_shape(
        SmemLayoutAtomPdS{},
        make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
    using SmemLayoutPdStransposed = decltype(
        composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));

    using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutQdOtransposed = decltype(
        composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
    using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));

    using SmemLayoutAtomdKV = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdKV = decltype(tile_to_shape(
        SmemLayoutAtomdKV{},
        make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;

    using SmemLayoutAtomdQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutdQ = decltype(tile_to_shape(
        SmemLayoutAtomdQ{},
        make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
    using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;

    // Double buffer for sQ
    static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
    static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
    static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
    static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
    static constexpr int kSmemSize = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    static constexpr int kSmemSize1colblock = kSmemQdOSize
        + (!Is_V_in_regs
           ? kSmemKVSize + kSmemdSSize + kSmemPSize
           : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
    // to affect speed in practice.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
    using GmemTiledCopydO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
    using GmemTiledCopydKV = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
    using GmemTiledCopydQ = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
    using GmemLayoutAtomdQaccum = std::conditional_t<
        kBlockKSmem == 32,
        Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
               Stride< _8, _1>>,
        Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
               Stride< _16, _1>>
    >;
    using GmemTiledCopydQaccum = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        GmemLayoutAtomdQaccum{},
                        Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store

    using GmemTiledCopydQaccumAtomicAdd = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                        Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
                               Stride<_32, _1>>{},
                        Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store

};

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash
@ -1,14 +0,0 @@

// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"

#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
}
} // namespace pytorch_flash