[ROCm] Update to AOTriton 0.8b (#140172)
Notable new features for SDPA operators on AMD systems from AOTriton 0.8b:
1. Nested tensor support.
2. MQA/GQA support.
3. Restored Efficient Attention support for causal=True when seqlen_q != seqlen_k. The kernel uses top-left alignment; bottom-right alignment will be added later.
4. gfx1100 (RX 7900 / W7800 / W7900) moved out of experimental support status. Users are nevertheless strongly recommended to update to ROCm 6.2.4, notably for its firmware updates.

Related unit tests are enabled as well.

Notable related changes from AOTriton 0.8b:
1. AOTriton 0.8b moves the GPU kernels out of libaotriton.so into a separate `aotriton.images` directory.
2. LZMA replaces ZSTD as the GPU kernel compression algorithm for a better compression ratio: the 0.8b .so plus aotriton.images take 350 MB, versus 800 MB for the 0.7b .so.
3. The compression can no longer be disabled, so `liblzma` becomes a hard run-time dependency. This should not be a problem, since `lzma` is part of the Python standard library.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 0a619a212f
Commit: 424156c26c
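As an illustration of what the new kernels enable at the Python level, the sketch below exercises SDPA with grouped-query attention and with nested (variable-length) inputs. It is not part of this commit; it assumes a ROCm build of PyTorch 2.6+ with AOTriton 0.8b on a supported AMD GPU, and the tensor shapes are arbitrary.

# Illustrative sketch only (not part of this commit): exercising the SDPA
# features that AOTriton 0.8b enables on ROCm -- GQA/MQA and nested tensors.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"  # ROCm devices are exposed through the CUDA device type
dtype = torch.float16

# Grouped-query attention: 8 query heads sharing 2 key/value heads.
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(2, 2, 128, 64, device=device, dtype=dtype)
v = torch.randn(2, 2, 128, 64, device=device, dtype=dtype)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Nested (variable-length) inputs: two sequences of different lengths,
# arranged as (batch, heads, seq*, head_dim) before calling SDPA.
seqs = [torch.randn(s, 8, 64, device=device, dtype=dtype) for s in (37, 128)]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged).transpose(1, 2)
out_nested = F.scaled_dot_product_attention(nt, nt, nt)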
@@ -1,5 +1,5 @@
-0.7b
+0.8b
-manylinux_2_17
+manylinux_2_28
 rocm6.2
-9be04068c3c0857a4cfd17d7e39e71d0423ebac2
+6f8cbcac8a92775291bb1ba8f514d4beb350baf4
-3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291
+e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8
@@ -253,11 +253,11 @@ make_wheel_record() {
     FPATH=$1
     if echo $FPATH | grep RECORD >/dev/null 2>&1; then
         # if the RECORD file, then
-        echo "$FPATH,,"
+        echo "\"$FPATH\",,"
     else
         HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g')
         FSIZE=$(ls -nl $FPATH | awk '{print $5}')
-        echo "$FPATH,sha256=$HASH,$FSIZE"
+        echo "\"$FPATH\",sha256=$HASH,$FSIZE"
     fi
 }
@@ -225,11 +225,11 @@ make_wheel_record() {
     FPATH=$1
     if echo $FPATH | grep RECORD >/dev/null 2>&1; then
         # if the RECORD file, then
-        echo "$FPATH,,"
+        echo "\"$FPATH\",,"
    else
         HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g')
         FSIZE=$(ls -nl $FPATH | awk '{print $5}')
-        echo "$FPATH,sha256=$HASH,$FSIZE"
+        echo "\"$FPATH\",sha256=$HASH,$FSIZE"
     fi
 }
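The two hunks above quote the path field of each wheel RECORD entry. RECORD is a CSV file per the wheel specification, so paths containing commas or quotes must be CSV-quoted (the AOTriton kernel-storage file names can now trip this). The sketch below is illustrative only, not part of the build scripts: it computes the same record fields as make_wheel_record (urlsafe base64 SHA-256 without padding, plus the file size) and lets Python's csv module apply quoting, which it does only when the path actually needs it, whereas the shell change quotes unconditionally.

# Illustrative sketch: compute one wheel RECORD line the way make_wheel_record does.
import base64
import csv
import hashlib
import io
import os

def wheel_record_entry(path: str) -> str:
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    # Same transformation as the openssl/sed pipeline: urlsafe base64, no '=' padding.
    hash_b64 = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
    size = os.path.getsize(path)
    buf = io.StringIO()
    csv.writer(buf, lineterminator="").writerow([path, f"sha256={hash_b64}", size])
    return buf.getvalue()

# e.g. wheel_record_entry("torch/lib/libaotriton_v2.so")  # path is illustrative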
@@ -266,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC))
 DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/})
 DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/})

+# PyTorch 2.6+ (AOTriton 0.8b+)
+# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly
+if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then
+    LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n')
+    if [[ -z ${LIBAOTRITON_DIR} ]]; then
+        LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n')
+    fi
+    AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n'))
+    AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images"
+    AKS_DST="lib/aotriton.images"
+    DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/})
+    DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/})
+fi
+
 echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"

 SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
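AOTriton 0.8b ships its GPU kernels as separate kernel-storage files under aotriton.images, next to libaotriton_v2.so, and decompresses them with LZMA at run time. The sketch below mirrors the shell discovery logic above in Python and confirms that the stdlib lzma module, the only new run-time dependency, is importable. It is illustrative only; the install prefix and file suffix pattern are assumptions.

# Illustrative sketch mirroring the discovery logic above; paths are assumptions.
import importlib
import pathlib

rocm_home = pathlib.Path("/opt/rocm")  # assumed ROCm install prefix

libs = (list(rocm_home.glob("lib/**/libaotriton_v2.so"))
        or list(rocm_home.glob("**/libaotriton_v2.so")))
if libs:
    images_dir = libs[0].parent / "aotriton.images"
    # Kernel-storage files (e.g. *.aks2) live in per-architecture subdirectories.
    aks_files = sorted(p.relative_to(images_dir)
                       for p in images_dir.rglob("*.aks*") if p.is_file())
    print(f"{len(aks_files)} kernel storage files under {images_dir}")

# LZMA decompression is served by the Python standard library.
importlib.import_module("lzma")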
@@ -1160,6 +1160,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();

   using aotriton::v2::flash::attn_fwd;
+  using aotriton::v2::flash::attn_fwd_compact_varlen;
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::mk_philoxtensor;
@@ -1172,22 +1173,46 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
   auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
   auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
   hipError_t err; // TODO: Error handling
-  err = attn_fwd(mk_aotensor(q_t, "q"),
-                 mk_aotensor(k_t, "k"),
-                 mk_aotensor(v_t, "v"),
-                 bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
-                 softmax_scale,
-                 mk_aotensor<2>(softmax_lse, "M"),
-                 mk_aotensor(output_t, "Out"),
-                 dropout_p,
-                 seed,
-                 offset1,
-                 offset2,
-                 seed_output,
-                 offset_output,
-                 mk_aotensor(softmax_fa_t, "encoded_softmax"),
-                 is_causal,
-                 stream);
+  if (seqstart_q.has_value()) {
+    // varlen aka nested tensor
+    err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
+                                  mk_aotensor(k_t, "k"),
+                                  mk_aotensor(v_t, "v"),
+                                  mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
+                                  mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
+                                  softmax_scale,
+                                  mk_aotensor<2>(softmax_lse, "M"),
+                                  mk_aotensor(output_t, "Out"),
+                                  dropout_p,
+                                  seed,
+                                  offset1,
+                                  offset2,
+                                  seed_output,
+                                  offset_output,
+                                  mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                                  is_causal,
+                                  stream);
+  } else {
+    err = attn_fwd(mk_aotensor(q_t, "q"),
+                   mk_aotensor(k_t, "k"),
+                   mk_aotensor(v_t, "v"),
+                   bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
+                   softmax_scale,
+                   mk_aotensor<2>(softmax_lse, "M"),
+                   mk_aotensor(output_t, "Out"),
+                   dropout_p,
+                   seed,
+                   offset1,
+                   offset2,
+                   seed_output,
+                   offset_output,
+                   mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                   is_causal,
+                   stream);
+  }
   if (!compute_logsumexp) {
     // Set the tensor to empty when compute_logsumexp is false
     logsumexp = at::empty(
@@ -441,29 +441,57 @@ _efficient_attention_backward(

   hipError_t err;
   using aotriton::v2::flash::attn_bwd;
+  using aotriton::v2::flash::attn_bwd_compact_varlen;
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::cast_dtype;
   aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
-  err = attn_bwd(mk_aotensor(q_t, "q"),
-                 mk_aotensor(k_t, "k"),
-                 mk_aotensor(v_t, "v"),
-                 bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
-                 softmax_scale,
-                 mk_aotensor(out_t, "out"),
-                 mk_aotensor(dout_t, "dout"),
-                 mk_aotensor(dq_t, "dq"),
-                 mk_aotensor(dk_t, "dk"),
-                 mk_aotensor(dv_t, "dv"),
-                 bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
-                 mk_aotensor<2>(softmax_lse, "L"),
-                 mk_aotensor<2>(delta, "delta"),
-                 float(dropout_p),
-                 mk_aoscalartensor(philox_seed),
-                 mk_aoscalartensor(philox_offset),
-                 0,
-                 is_causal,
-                 stream);
+  if (cu_seqlens_q.has_value()) {
+    // varlen aka Nested tensor
+    err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
+                                  mk_aotensor(k_t, "k"),
+                                  mk_aotensor(v_t, "v"),
+                                  mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
+                                  mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
+                                  softmax_scale,
+                                  mk_aotensor(out_t, "out"),
+                                  mk_aotensor(dout_t, "dout"),
+                                  mk_aotensor(dq_t, "dq"),
+                                  mk_aotensor(dk_t, "dk"),
+                                  mk_aotensor(dv_t, "dv"),
+                                  bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
+                                  mk_aotensor<2>(softmax_lse, "L"),
+                                  mk_aotensor<2>(delta, "delta"),
+                                  float(dropout_p),
+                                  mk_aoscalartensor(philox_seed),
+                                  mk_aoscalartensor(philox_offset),
+                                  0,
+                                  is_causal,
+                                  stream);
+  } else {
+    err = attn_bwd(mk_aotensor(q_t, "q"),
+                   mk_aotensor(k_t, "k"),
+                   mk_aotensor(v_t, "v"),
+                   bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
+                   softmax_scale,
+                   mk_aotensor(out_t, "out"),
+                   mk_aotensor(dout_t, "dout"),
+                   mk_aotensor(dq_t, "dq"),
+                   mk_aotensor(dk_t, "dk"),
+                   mk_aotensor(dv_t, "dv"),
+                   bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
+                   mk_aotensor<2>(softmax_lse, "L"),
+                   mk_aotensor<2>(delta, "delta"),
+                   float(dropout_p),
+                   mk_aoscalartensor(philox_seed),
+                   mk_aoscalartensor(philox_offset),
+                   0,
+                   is_causal,
+                   stream);
+  }
 #else
   at::Tensor workspace;
   cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
@@ -103,6 +103,10 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
   return matmul_alignment_mn;
 }

+// On ROCM, ME and FA share the backend, and hence they share the checking
+// function for fundamental limitations by the GPU kernel
+// caller_is_meff is added to make the TORCH_WARN message showing the correct result
+template<bool caller_is_meff = false>
 bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
   // All head_dim sizes must be equal and less than 256
   const auto max_size = c10::SymInt(256);
@@ -114,7 +118,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
   if (!(same_head_dim_size && (query_size_last <= max_size))) {
     if (debug) {
       TORCH_WARN(
-          "Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.",
+          caller_is_meff ? "Efficient attention on ROCM" : "Flash attention",
+          " requires q,k,v to have the same last dimension and to be less than or equal to 256.",
           " Got Query.size(-1): ",
           query_size_last,
           ", Key.size(-1): ",
@@ -128,6 +133,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
   return true;
 }

+// See check_head_dim_size_flash above for the purpose of caller_is_meff
+template<bool caller_is_meff = false>
 bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
   const auto max_size = c10::SymInt(256);
   const auto query_size_last = params.query.sym_size(-1);
@@ -139,7 +146,9 @@ bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
         (query_size_last <= max_size))) {
     if (debug) {
       TORCH_WARN(
-          "For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
+          "For NestedTensor inputs,",
+          caller_is_meff ? " Efficient attention on ROCM " : " Flash attention",
+          " requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
           " Got Query.size(-1): ",
           query_size_last,
           ", Key.size(-1): ",
@@ -208,7 +217,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   // Check that the gpu is capable of running flash attention
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
-  auto dprops = at::cuda::getCurrentDeviceProperties();
 #if USE_ROCM
 #if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -220,19 +228,11 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
     }
     return false;
   }
-  c10::string_view arch(dprops->gcnArchName);
-  if (arch == "gfx1100") {
-    static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
-    if (!enable_navi3x) {
-      TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental."
-                      " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
-      return false;
-    }
-  }
 #else
   return false;
 #endif
 #else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm80, sm90>(dprops)) {
     if (debug) {
       TORCH_WARN(
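With gfx1100 moved out of experimental status, the TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL gate removed above is no longer needed: on an RX 7900-class GPU the flash backend can be requested directly. The sketch below is illustrative only and assumes a ROCm PyTorch build containing this change on a gfx1100 device.

# Illustrative sketch: flash SDPA on gfx1100 without the experimental env var.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q, k, v = (torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))
with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# ROCm builds expose the GPU architecture name, e.g. "gfx1100".
print(torch.cuda.get_device_properties(0).gcnArchName)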
@@ -252,7 +252,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
-  auto dprops = at::cuda::getCurrentDeviceProperties();
 #if USE_ROCM
 #if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -264,19 +263,11 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
     }
     return false;
   }
-  c10::string_view arch(dprops->gcnArchName);
-  if (arch == "gfx1100") {
-    static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
-    if (!enable_navi3x) {
-      TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental."
-                      " Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
-      return false;
-    }
-  }
 #else
   return false;
 #endif
 #else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm50, sm90>(dprops)) {
     if (debug) {
       TORCH_WARN(
@@ -615,7 +606,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
       check_all_tensors_on_device,
       check_tensor_shapes,
       check_for_attn_mask,
-      check_head_dim_size_flash,
+      check_head_dim_size_flash<false /*caller_is_meff*/>,
       check_flash_attention_hardware_support,
       check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
       check_flash_causal_non_square_seqlens,
@@ -629,7 +620,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
   if (has_for_nested_inputs(params)) {
     constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
         check_batch_size_nested,
-        check_head_dim_size_flash_nested,
+        check_head_dim_size_flash_nested<false /*caller_is_meff*/>,
         check_for_seq_len_0_nested_tensor);
     for (auto& constraint : nested_constraints) {
       if (!constraint(params, debug)) {
@@ -637,11 +628,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
       }
     }
   }
-#if USE_ROCM
-  constexpr bool backend_supports_grouped_query_attention = false;
-#else
   constexpr bool backend_supports_grouped_query_attention = true;
-#endif
   if (has_only_dense_inputs(params)) {
     constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
         check_batch_size_and_num_heads_dense<backend_supports_grouped_query_attention>,
@@ -679,7 +666,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       check_mem_efficient_hardware_support,
       check_tensor_shapes,
 #ifdef USE_ROCM
-      check_head_dim_size_flash
+      check_head_dim_size_flash<true /* caller_is_meff */>
 #else
       check_head_dim_size_mem_efficient
 #endif
@@ -691,12 +678,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   }

   if (has_for_nested_inputs(params)) {
-#ifdef USE_ROCM
-    TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
-    return false;
-#endif
     constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
+#ifndef USE_ROCM // ME and FA shares backend on ROCM and thus supports training
         check_requires_grad_and_nested,
+#else // Meanwhile ME on ROCM share the limits of FA about head dimensions
+        check_head_dim_size_flash_nested<true /* caller_is_meff */>,
+#endif
         check_batch_size_nested,
         check_for_seq_len_0_nested_tensor);
     for (auto& constraint : nested_constraints) {
@@ -77,6 +77,37 @@ void check_gpu_arch(hipStream_t stream) {
   }
 }

+// We want to checkpoint and save the RNG state for backward if dropout
+// We get the default generator and return the seed and offset which will
+// be used in the backward function
+std::tuple<at::Tensor, at::Tensor, at::PhiloxCudaState, bool>
+prepare_philox_arguments(float p_dropout, int64_t counter_offset) {
+  at::Tensor seed_t, offset_t;
+  at::PhiloxCudaState philox_state;
+  bool use_philox_state = false;
+  if (p_dropout <= 0.0) {
+    seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    return { seed_t, offset_t, philox_state, use_philox_state };
+  }
+  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
+  std::lock_guard<std::mutex> lock(gen->mutex_);
+  philox_state = gen->philox_cuda_state(counter_offset);
+  if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
+    auto [seed, offset] = at::cuda::philox::unpack(philox_state);
+    seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
+  } else {
+    // See Note [CUDA Graph-safe RNG states] about the design
+    use_philox_state = true;
+    seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+    offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+  }
+
+  return { seed_t, offset_t, philox_state, use_philox_state };
+}
+
 }

 #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@@ -158,44 +189,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
   const int head_size = round_multiple(head_size_og, 8);
   const int head_size_rounded = round_multiple(head_size, 32);
-  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
-  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

-  // We want to checkpoint and save the RNG state for backward if dropout
-  // We get the default generator and return the seed and offset which will
-  // be used in the backward function
-  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
-  at::Tensor seed_t, offset_t;
-
-  at::PhiloxCudaState philox_state;
-  bool use_philox_state = false;
-  if (p_dropout > 0.0) {
-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = batch_size * num_heads * 32;
-    // See Note [Acquire lock when using random generators]
-    std::lock_guard<std::mutex> lock(gen->mutex_);
-    philox_state = gen->philox_cuda_state(counter_offset);
-    if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
-      auto [seed, offset] = at::cuda::philox::unpack(philox_state);
-      seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
-      offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
-    } else {
-      // See Note [CUDA Graph-safe RNG states] about the design
-      use_philox_state = true;
-      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-    }
-  } else {
-    if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
-      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-    } else {
-      seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-      offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
-    }
-  }
+  auto [seed_t, offset_t, philox_state, use_philox_state] =
+      prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);

   // Transpose tensors to meet AOTriton's Flash API
   at::Tensor q_t = q_padded.permute({0,2,1,3});
@@ -215,7 +211,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

   hipError_t err; // TODO: Error handling
   using aotriton::v2::flash::attn_fwd;
-  using aotriton::TensorView;
   using sdp::aotriton_adapter::mk_aotensor;
   using sdp::aotriton_adapter::mk_aoscalartensor;
   using sdp::aotriton_adapter::mk_philoxtensor;
@@ -266,16 +261,150 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                int window_size_right,
                const bool return_softmax,
                std::optional<at::Generator> gen_) {
-  TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");
-
-  at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
-  at::Tensor p = at::empty({}, at::dtype(at::kFloat));
-  at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
-  at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
-  at::Tensor out = at::empty({}, at::dtype(at::kFloat));
-
-  return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
+  TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
+  const bool paged_KV = block_table_.has_value();
+  TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
+  TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
+
+  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
+  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
+  check_gpu_arch(stream);
+
+  auto q_dtype = q.dtype();
+  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
+              "FlashAttention only support fp16 and bf16 data type");
+  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+  TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
+  TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
+
+  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+  CHECK_DEVICE(cu_seqlens_q);
+  CHECK_DEVICE(cu_seqlens_k);
+
+  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  CHECK_CONTIGUOUS(cu_seqlens_q);
+  CHECK_CONTIGUOUS(cu_seqlens_k);
+
+  const auto sizes = q.sizes();
+
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  int num_heads = sizes[1];
+  const int head_size_og = sizes[2];
+  const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+
+  if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) {
+    is_causal = false;
+  } // causal=true is the same as causal=false in this case
+
+  at::Tensor temp_q = q;
+  const int total_q = temp_q.sizes()[0];
+
+  TORCH_CHECK(batch_size > 0, "batch size must be positive");
+  TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
+  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+  if (window_size_left >= max_seqlen_k) {
+    window_size_left = -1;
+  }
+  if (window_size_right >= max_seqlen_k) {
+    window_size_right = -1;
+  }
+
+  CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
+  const int total_k = k.size(0);
+  CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
+  CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+  CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
+  // AOTriton's varlen API needs input shapes be
+  // (1, num_heads, total sequence lenght, head dimension)
+  at::Tensor q_padded, k_padded, v_padded;
+  at::Tensor out, out_padded;
+  q_padded = q.unsqueeze(0).transpose(1, 2);
+  k_padded = k.unsqueeze(0).transpose(1, 2);
+  v_padded = v.unsqueeze(0).transpose(1, 2);
+  if (out_.has_value()) {
+    out = out_.value();
+    TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
+    CHECK_DEVICE(out);
+    TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
+    CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+  } else {
+    out = at::empty_like(q);
+  }
+  out_padded = out.unsqueeze(0).transpose(1, 2);
+
+  auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+  const int head_size = head_size_og;
+
+  auto opts = q.options();
+
+  auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+  at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q});
+  at::Tensor softmax_fa_t;
+  // Only return softmax if there's dropout to reduce compilation time
+  if (return_softmax) {
+    TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
+    softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts);
+  } else {
+    softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts);
+  }
+
+  if (zero_tensors) {
+    out.zero_();
+    softmax_lse.fill_(-std::numeric_limits<float>::infinity());
+    if (return_softmax) {
+      softmax_fa_t.zero_();
+    }
+  }
+
+  auto [seed_t, offset_t, philox_state, use_philox_state] =
+      prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
+
+  if (max_seqlen_k > 0) {
+    hipError_t err; // TODO: Error handling
+    using aotriton::v2::flash::attn_fwd_compact_varlen;
+    using sdp::aotriton_adapter::mk_aotensor;
+    using sdp::aotriton_adapter::mk_aoscalartensor;
+    using sdp::aotriton_adapter::mk_philoxtensor;
+    using sdp::aotriton_adapter::cast_dtype;
+    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
+    auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
+    auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
+    auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
+    auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
+    auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
+    err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                  mk_aotensor(k_padded, "k"),
+                                  mk_aotensor(v_padded, "v"),
+                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  empty_bias,
+                                  softmax_scale,
+                                  mk_aotensor<2>(M, "M"),
+                                  mk_aotensor(out_padded, "Out"),
+                                  p_dropout,
+                                  seed,
+                                  offset1,
+                                  offset2,
+                                  seed_output,
+                                  offset_output,
+                                  mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                                  is_causal,
+                                  stream);
+  } else {
+    // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
+    out.zero_();
+    softmax_lse.fill_(std::numeric_limits<float>::infinity());
+  }
+
+  return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t};
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -297,6 +426,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
         const bool deterministic,
         const at::Tensor philox_seed,
         const at::Tensor philox_offset) {
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
   check_gpu_arch(stream);

@@ -341,8 +473,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si

   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
   const int head_size_rounded = round_multiple(head_size, 32);
-  const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
-  const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

   TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

@@ -381,23 +511,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     dv = at::empty_like(k);
   }

-  // const at::Tensor& dout_padded = dout;
-
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
-
   auto opts = q.options();
-  auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
+  auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

-  at::Tensor dk_expanded, dv_expanded;
-  if (num_heads_k != num_heads) { // MQA / GQA
-    dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
-    dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
-  } else {
-    dk_expanded = dk;
-    dv_expanded = dv;
-  }
-
   at::Tensor q_t = q.permute({0,2,1,3});
   at::Tensor k_t = k.permute({0,2,1,3});
@@ -440,14 +555,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
                    stream);
   }

-  // For MQA/GQA we need to sum dK and dV across the groups
-  if (num_heads_k != num_heads) {
-    at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
-    at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
-  }
   return { dq, dk, dv, softmax_d };
-#undef CALL_BWD_DROPOUT
-#undef CALL_BWD
 }

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -473,13 +581,173 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                int window_size_right,
                const bool deterministic,
                const at::Tensor philox_seed,
-               const at::Tensor philox_offset) {
-  TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");
-
-  at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));
-
-  return { q, k, v, softmax_d };
+               const at::Tensor philox_offset)
+{
+  TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
+
+  if (is_causal) {
+    window_size_right = 0;
+  }
+  // Otherwise the kernel will be launched from cuda:0 device
+  // Cast to char to avoid compiler warning about narrowing
+  at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
+  auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
+  check_gpu_arch(stream);
+
+  bool is_dropout = p_dropout > 0.0;
+
+  auto q_dtype = q.dtype();
+  TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
+              "FlashAttention only support fp16 and bf16 data type");
+  TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
+  TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
+  TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
+  TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
+  TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
+  TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
+
+  CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
+  CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
+  CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
+
+  TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
+  TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
+  TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
+  CHECK_CONTIGUOUS(cu_seqlens_q);
+  CHECK_CONTIGUOUS(cu_seqlens_k);
+
+  const auto sizes = q.sizes();
+
+  const int total_q = sizes[0];
+  const int batch_size = cu_seqlens_q.numel() - 1;
+  const int num_heads = sizes[1];
+  const int head_size_og = dout.size(2);
+  const int head_size = sizes[2];
+  const int total_k = k.size(0);
+  const int num_heads_k = k.size(1);
+  TORCH_CHECK(batch_size > 0, "batch size must be positive");
+  TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
+  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+
+  if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
+  if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
+
+  CHECK_SHAPE(q, total_q, num_heads, head_size);
+  CHECK_SHAPE(k, total_k, num_heads_k, head_size);
+  CHECK_SHAPE(v, total_k, num_heads_k, head_size);
+  CHECK_SHAPE(out, total_q, num_heads, head_size);
+  CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
+  CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
+  CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
+
+  at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous();
+  at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
+
+  at::Tensor q_padded, k_padded, v_padded;
+  q_padded = q.unsqueeze(0).transpose(1, 2);
+  k_padded = k.unsqueeze(0).transpose(1, 2);
+  v_padded = v.unsqueeze(0).transpose(1, 2);
+  at::Tensor out_t, dout_t;
+  out_t = out.unsqueeze(0).transpose(1, 2);
+  dout_t = dout.unsqueeze(0).transpose(1, 2);
+
+  at::Tensor dq, dk, dv;
+  at::Tensor dq_padded, dk_padded, dv_padded;
+  if (dq_.has_value()) {
+    dq = dq_.value();
+    TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
+    CHECK_DEVICE(dq);
+    TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
+    CHECK_SHAPE(dq, total_q, num_heads, head_size);
+  } else {
+    dq = at::empty_like(q);
+  }
+  if (dk_.has_value()) {
+    dk = dk_.value();
+    TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
+    CHECK_DEVICE(dk);
+    TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
+    CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
+  } else {
+    dk = at::empty_like(k);
+  }
+  if (dv_.has_value()) {
+    dv = dv_.value();
+    TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
+    CHECK_DEVICE(dv);
+    TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
+    CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
+  } else {
+    dv = at::empty_like(v);
+  }
+  dq_padded = dq.unsqueeze(0).transpose(1, 2);
+  dk_padded = dk.unsqueeze(0).transpose(1, 2);
+  dv_padded = dv.unsqueeze(0).transpose(1, 2);
+
+  auto opts = q.options();
+  auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
+
+  if( zero_tensors ) {
+    dq.zero_();
+    dk.zero_();
+    dv.zero_();
+    softmax_d.zero_();
+  }
+
+  at::PhiloxCudaState philox_args;
+  if (is_dropout) {
+    if (at::cuda::currentStreamCaptureStatus() ==
+        at::cuda::CaptureStatus::None)
+    {
+      philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
+    } else { // dropout + capture
+      philox_args = at::PhiloxCudaState(
+          philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
+    }
+  }
+  if (max_seqlen_q > 0) {
+    hipError_t err; // TODO: Error handling
+    using aotriton::v2::flash::attn_bwd_compact_varlen;
+    using sdp::aotriton_adapter::mk_aotensor;
+    using sdp::aotriton_adapter::mk_aoscalartensor;
+    using sdp::aotriton_adapter::cast_dtype;
+    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
+    err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                  mk_aotensor(k_padded, "k"),
+                                  mk_aotensor(v_padded, "v"),
+                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                  max_seqlen_q,
+                                  max_seqlen_k,
+                                  empty_bias,
+                                  softmax_scale,
+                                  mk_aotensor(out_t, "out"),
+                                  mk_aotensor(dout_t, "dout"),
+                                  mk_aotensor(dq_padded, "dq"),
+                                  mk_aotensor(dk_padded, "dk"),
+                                  mk_aotensor(dv_padded, "dv"),
+                                  empty_bias,
+                                  mk_aotensor<2>(softmax_lse_cont, "L"),
+                                  mk_aotensor<2>(delta, "delta"),
+                                  p_dropout,
+                                  mk_aoscalartensor(philox_seed),
+                                  mk_aoscalartensor(philox_offset),
+                                  0,
+                                  is_causal,
+                                  stream);
+  } else {
+    // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
+    dq.zero_();
+    dk.zero_();
+    dv.zero_();
+    softmax_d.zero_();
+  }
+
+  return { dq, dk, dv, softmax_d };
 }

 } // namespace pytorch_fmha

 #endif
setup.py (+7)
@@ -1373,6 +1373,13 @@ def main():
                 "lib/*.lib",
             ]
         )
+    aotriton_image_path = os.path.join(lib_path, "aotriton.images")
+    aks2_files = []
+    for root, dirs, files in os.walk(aotriton_image_path):
+        subpath = os.path.relpath(root, start=aotriton_image_path)
+        for fn in files:
+            aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn))
+    torch_package_data += aks2_files
     if get_cmake_cache_vars()["USE_TENSORPIPE"]:
         torch_package_data.extend(
             [
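After this setup.py change the kernel-storage files are listed in torch_package_data, so they end up inside the wheel under torch/lib/aotriton.images. A quick way to confirm is to list the built wheel, since a wheel is a zip archive; the sketch below is illustrative and the wheel filename is an assumption.

# Illustrative check: list packaged AOTriton kernel files in a built wheel.
import zipfile

with zipfile.ZipFile("dist/torch-2.6.0-cp312-cp312-linux_x86_64.whl") as wheel:
    aks2 = [name for name in wheel.namelist()
            if "lib/aotriton.images/" in name]
print(f"{len(aks2)} kernel storage files packaged")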
@@ -389,8 +389,6 @@ class TestFlexAttention(InductorTestCase):
             KV_S = Q_S
         if V_D is None:
             V_D = Q_D
-        if TEST_WITH_ROCM and Q_H != KV_H:
-            self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
         q = torch.randn(
             (Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
         )
@@ -565,9 +563,6 @@ class TestFlexAttention(InductorTestCase):
         V_D: int = D,
         block_mask: Optional[BlockMask] = None,
     ):
-        if TEST_WITH_ROCM and Q_H != KV_H:
-            self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
-
         assert Q_H % KV_H == 0

         q = torch.randn(
@@ -22,7 +22,7 @@ from torch.nn.attention.flex_attention import (
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
-from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM
+from torch.testing._internal.common_utils import skipIfRocm
 from torch.utils._triton import has_triton
@@ -492,9 +492,6 @@ class TestFlexDecoding(InductorTestCase):
         V_D: int = D,
         block_mask: Optional[BlockMask] = None,
     ):
-        if TEST_WITH_ROCM and Q_H != KV_H:
-            self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
-
         assert Q_H % KV_H == 0

         q = torch.randn(
@@ -277,8 +277,8 @@ class TestMHADeviceType(TestCase):
     def test_native_multihead_self_attention(self, device, dtype, use_nt,
                                              need_weights, average_attn_weights, use_padding, pad_all, fused):
         if TEST_WITH_ROCM:
-            if use_nt:
-                self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
+            if use_nt and use_padding and pad_all:
+                self.skipTest("Large numerical errors on ROCM to investigate.")
             if use_padding and not pad_all and fused:
                 self.skipTest("Large numerical errors on ROCM to investigate.")
         for need_weights in (False, not pad_all):
@@ -1630,7 +1630,6 @@ class TestSDPAFailureModes(NNTestCase):
                 q, k, v, None, 0.0, False))

     @onlyCUDA
-    @skipIfRocm(msg='enable_gqa=True unsupported')
    @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
     @parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
     def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
@@ -1646,7 +1645,6 @@ class TestSDPAFailureModes(NNTestCase):
                                                is_causal=False, enable_gqa=True)

     @onlyCPU
-    @skipIfRocm(msg='enable_gqa=True unsupported')
     def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
         rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
         rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
@@ -1784,7 +1782,6 @@ class TestSDPAFailureModes(NNTestCase):
         self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))

     @onlyCUDA
-    @skipIfRocm  # Missing EFFICIENT_ATTENTION
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     def test_fused_kernels_nested_broadcasting_error_cases(self, device):
         # one of k,v needs to be broadcasted and other has non consistent seq_len dim
@@ -2485,7 +2482,7 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)

-    @skipIfRocm  # No cuDNN Attention
+    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
     def test_fused_attention_different_dk_dv(self, device):
         dtype = torch.bfloat16
@@ -2526,7 +2523,7 @@ class TestSDPACudaOnly(NNTestCase):
         with self.assertRaisesRegex(RuntimeError, "No available kernel."):
             o = torch.nn.functional.scaled_dot_product_attention(q, k, v)

-    @skipIfRocm  # No cuDNN Attention
+    @skipIfRocm(msg="No cuDNN on ROCm")
     @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
     def test_cudnn_attention_trivial_output_transpose(self, device):
         # see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2713,8 +2710,6 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
     def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
-        if TEST_WITH_ROCM and type == 'nested':
-            self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
         make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)

         batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
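With the nested-tensor skip removed, this test now exercises the fused SDPA kernels on ragged batches on ROCm as well. A minimal sketch of the kind of call involved, assuming a ROCm (or CUDA) device and a build with efficient attention; shapes and sequence lengths are illustrative only:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"  # ROCm builds expose HIP GPUs through the "cuda" device type
heads, head_dim = 8, 64

def make_seq(seq_len):
    # One (seq_len, heads, head_dim) sequence; lengths differ across the batch.
    return torch.randn(seq_len, heads, head_dim, device=device, dtype=torch.float16)

q, k, v = (torch.nested.nested_tensor([make_seq(n) for n in (32, 48, 64)])
           for _ in range(3))

# Pin the efficient-attention backend, mirroring the fused_kernel parametrization above.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))  # (batch, heads, seq*, head_dim)

print(out.is_nested)  # True: the output keeps the ragged sequence dimension
```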
@@ -2743,7 +2738,6 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)

-    @skipIfRocm # Missing nested and EFFICIENT_ATTENTION
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
@@ -2910,7 +2904,6 @@ class TestSDPACudaOnly(NNTestCase):
         atol = 9e-4 if dtype == torch.float16 else 9e-3
         self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)

-    @skipIfRocm # Missing nested and EFFICIENT_ATTENTION
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
     @parametrize("type", ["dense", "nested"])
     def test_fused_sdp_choice(self, device, type: str):
@@ -3116,8 +3109,6 @@ class TestSDPACudaOnly(NNTestCase):
             return
         if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
             torch.cuda.empty_cache() # Prevent memory fragmentation
-        if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
-            self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")
         seed = 42
         scale = scale if scale is None else (1 / head_dim)
         n_heads = 4
@@ -3227,9 +3218,6 @@ class TestSDPACudaOnly(NNTestCase):
         if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
             unittest.skip("Reference implementation OOM")
             return
-        if TEST_WITH_ROCM and dtype == torch.float32:
-            unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.")
-            return
         if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
             torch.cuda.empty_cache() # Prevent memory fragmentation
         seed = 42
@@ -3323,7 +3311,7 @@ class TestSDPACudaOnly(NNTestCase):
     @parametrize("dropout_p", [0.0, 0.22, 0.48])
     @parametrize("dtype", [torch.float16, torch.bfloat16])
     @parametrize("scale", [None, "l1"])
-    @parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
+    @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
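Dropping the ROCm special case means the GQA path is now parametrized on every platform. For reference, a minimal sketch of what `enable_gqa=True` does at the SDPA call site (shapes and device are illustrative; it assumes a GPU build with fused attention available):

```python
import torch
import torch.nn.functional as F

device, dtype = "cuda", torch.float16
batch, seq_len, head_dim = 2, 128, 64

# Grouped-query attention: 8 query heads share 2 key/value heads.
query = torch.randn(batch, 8, seq_len, head_dim, device=device, dtype=dtype)
key = torch.randn(batch, 2, seq_len, head_dim, device=device, dtype=dtype)
value = torch.randn(batch, 2, seq_len, head_dim, device=device, dtype=dtype)

# enable_gqa=True lets SDPA broadcast the 2 KV heads over the 8 query heads
# instead of requiring the head counts to match.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```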
@@ -3488,8 +3476,6 @@ class TestSDPACudaOnly(NNTestCase):

         if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
             self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
-        if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
-            self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")

         seed = 42
         n_heads = 4
@@ -3593,7 +3579,6 @@ class TestSDPACudaOnly(NNTestCase):
             }
         )

-    @skipIfRocm # Nested Tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                  PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@@ -3627,7 +3612,6 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)

-    @skipIfRocm # Nested tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
                  PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@@ -3653,6 +3637,9 @@ class TestSDPACudaOnly(NNTestCase):
         rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
         batch, num_heads, head_dim = 32, 8, 64
         head_dim_v = 32 if is_efficient else head_dim
+        if TEST_WITH_ROCM and head_dim != head_dim_v:
+            self.skipTest("head_dim != head_dim_v unsupported on ROCm for now")
+            return
         seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
                       if expand_q_batch
                       else torch.randint(low=1, high=32, size=(batch,)).tolist())
@@ -3716,7 +3703,7 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2)

-    @skipIfRocm # Nested tensor
+    @skipIfRocm(msg="Efficient Attention on ROCM does not support head_dim != head_dim_v for now.")
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     def test_fused_kernels_nested_broadcasting_query_dense(self, device):
         rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
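The new message spells out what remains skipped: efficient attention on ROCm does not yet handle a value head dimension that differs from the query/key head dimension. As an illustration of the shape pattern in question (this snippet is not tied to any particular backend and runs on ones that do support it, such as the math fallback):

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 64, 64)   # head_dim   = 64
k = torch.randn(2, 8, 64, 64)
v = torch.randn(2, 8, 64, 32)   # head_dim_v = 32

# The output inherits the value head dimension, not the query/key one.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 8, 64, 32])
```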
@@ -3751,7 +3738,6 @@ class TestSDPACudaOnly(NNTestCase):

         self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)

-    @skipIfRocm # Nested tensor
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
     @parametrize("batch_size", [8, 32])
     @parametrize("max_seq_len_q", [32, 256])
@@ -3919,7 +3905,6 @@ class TestAttnBias(NNTestCase):
         torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
         torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)

-    @skipIfRocm # No support for the second variant for now
     @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
     @parametrize(
         "shape",
@@ -3929,6 +3914,9 @@ class TestAttnBias(NNTestCase):
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
         )
+        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
+            self.skipTest("No support for LOWER_RIGHT variant for now")
+            return

         bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
         make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
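These causal-variant tests now run on ROCm for the UPPER_LEFT alignment, while LOWER_RIGHT is still skipped in-method. As a hedged illustration of what the two variants mean at the SDPA call site (assuming a recent PyTorch that ships `torch.nn.attention.bias`; device and shapes are arbitrary):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.bias import causal_lower_right, causal_upper_left

device, dtype = "cuda", torch.float16
q = torch.rand(4, 8, 64, 32, device=device, dtype=dtype)   # seq_len_q  = 64
k = torch.rand(4, 8, 128, 32, device=device, dtype=dtype)  # seq_len_kv = 128
v = torch.rand(4, 8, 128, 32, device=device, dtype=dtype)

# UPPER_LEFT anchors the causal triangle at the top-left corner (same as is_causal=True);
# LOWER_RIGHT anchors it at the bottom-right, which only differs when seq_len_q != seq_len_kv.
out_ul = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_upper_left(64, 128))
out_lr = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_lower_right(64, 128))
```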
@@ -3952,7 +3940,6 @@ class TestAttnBias(NNTestCase):
                                  SDPBackend.CUDNN_ATTENTION]):
             self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)

-    @skipIfRocm # CausalVariant
     @parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
     @parametrize(
         "shape",
@@ -3961,6 +3948,10 @@ class TestAttnBias(NNTestCase):
     @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
     @skipIfTorchDynamo("This function already calls torch.compile.")
     def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
+        if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
+            self.skipTest("No support for LOWER_RIGHT variant for now")
+            return
+
         cnts = CompileCounterWithBackend("aot_eager")
         make_tensor = partial(
             torch.rand, device=device, dtype=torch.float16, requires_grad=True
@@ -3988,7 +3979,6 @@ class TestAttnBias(NNTestCase):
             self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
         self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")

-    @skipIfRocm
     @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
     def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
         make_tensor = partial(
@@ -8762,11 +8762,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_

     qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
     samples = []
-    gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support
-    if TEST_WITH_ROCM and dtype == torch.float32:
-        causal_options = [False] # FIXME: Large errors with causal+fp32
-    else:
-        causal_options = [True, False]
+    gqa_options = [True, False]
+    causal_options = [True, False]
     for qkv_shape, is_causal, dropout_p, _enable_gqa in product(
             qkv_shapes, causal_options, [0.0, 0.5], gqa_options):
         shape_q, shape_kv = qkv_shape
@@ -8799,16 +8796,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
                 dropout_p=0.0)
     )

-    if not TEST_WITH_ROCM:
-        samples.append(
-            SampleInput(
-                make((batch, num_heads_q_gqa, seq_q, head_dim)),
-                make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
-                make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
-                enable_gqa=True
-            )
-        )
-
     yield from samples
