[ROCm] Update to AOTriton 0.8b (#140172)

Notable new features for SDPA operators on AMD systems, enabled by AOTriton 0.8b:

1. Nested tensor support;
2. MQA/GQA support;
3. Restored Efficient attention support for causal=True with seqlen_q != seqlen_k;
    + The kernel uses top-left mask alignment; bottom-right alignment will be added later.
4. gfx1100 (RX7900/W7800/W7900) moved out of experimental support status.
   However, users are strongly encouraged to update to ROCm 6.2.4, notably for
   its firmware updates.

Related unit tests are enabled as well (see the sketch below).
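
A minimal usage sketch (not part of this PR), assuming a ROCm build of PyTorch with a supported AMD GPU; it exercises grouped-query attention and causal Efficient attention with seqlen_q != seqlen_k through the public SDPA API:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device, dtype = "cuda", torch.float16

# MQA/GQA: 16 query heads share 4 key/value heads.
q = torch.randn(2, 16, 256, 64, device=device, dtype=dtype)
k = torch.randn(2, 4, 256, 64, device=device, dtype=dtype)
v = torch.randn(2, 4, 256, 64, device=device, dtype=dtype)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Efficient attention with a causal mask and seqlen_q != seqlen_k
# (top-left mask alignment).
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(2, 8, 512, 64, device=device, dtype=dtype)
v = torch.randn(2, 8, 512, 64, device=device, dtype=dtype)
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```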

Notable related changes in AOTriton 0.8b:

1. AOTriton 0.8b moves the GPU kernels out of libaotriton.so into a separate directory, `aotriton.images`;
2. LZMA replaces ZSTD as the GPU kernel compression algorithm for a better compression ratio: the AOTriton 0.8b .so plus `aotriton.images` take about 350 MB, compared to about 800 MB for the AOTriton 0.7b .so alone;
3. Compression can no longer be disabled, and `liblzma` is now a hard run-time dependency.
    + This should not be a problem, since `lzma` is part of the Python standard library (see the sketch below).
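
As a hypothetical post-install check (not part of this PR), the sketch below verifies that the standard-library `lzma` module, which needs `liblzma`, is importable, and reports the size of the kernel images under `torch/lib/aotriton.images`, the install location used by the packaging changes in this PR; the directory only exists on ROCm builds that bundle AOTriton.

```python
import lzma  # stdlib module backed by liblzma; the import fails if liblzma is missing
import os

import torch

images_dir = os.path.join(os.path.dirname(torch.__file__), "lib", "aotriton.images")
if os.path.isdir(images_dir):
    # Sum the sizes of all AOTriton kernel storage files shipped with the wheel.
    total = sum(
        os.path.getsize(os.path.join(root, f))
        for root, _, files in os.walk(images_dir)
        for f in files
    )
    print(f"aotriton.images payload: {total / 2**20:.1f} MiB")
else:
    print("no aotriton.images payload (not a ROCm/AOTriton build)")
```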

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140172
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily

Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Authored by: Xinya Zhang
Date: 2024-12-06 21:45:18 +00:00
Committed by: PyTorch MergeBot
Parent: 0a619a212f
Commit: 424156c26c
14 changed files with 498 additions and 200 deletions

View File

@ -1,5 +1,5 @@
0.7b
manylinux_2_17
0.8b
manylinux_2_28
rocm6.2
9be04068c3c0857a4cfd17d7e39e71d0423ebac2
3e9e1959d23b93d78a08fcc5f868125dc3854dece32fd9458be9ef4467982291
6f8cbcac8a92775291bb1ba8f514d4beb350baf4
e938def5d32869fe2e00aec0300f354c9f157867bebdf2e104d732b94cb238d8

View File

@ -253,11 +253,11 @@ make_wheel_record() {
FPATH=$1
if echo $FPATH | grep RECORD >/dev/null 2>&1; then
# if the RECORD file, then
echo "$FPATH,,"
echo "\"$FPATH\",,"
else
HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g')
FSIZE=$(ls -nl $FPATH | awk '{print $5}')
echo "$FPATH,sha256=$HASH,$FSIZE"
echo "\"$FPATH\",sha256=$HASH,$FSIZE"
fi
}

View File

@ -225,11 +225,11 @@ make_wheel_record() {
FPATH=$1
if echo $FPATH | grep RECORD >/dev/null 2>&1; then
# if the RECORD file, then
echo "$FPATH,,"
echo "\"$FPATH\",,"
else
HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g')
FSIZE=$(ls -nl $FPATH | awk '{print $5}')
echo "$FPATH,sha256=$HASH,$FSIZE"
echo "\"$FPATH\",sha256=$HASH,$FSIZE"
fi
}

View File

@ -266,6 +266,20 @@ RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC))
DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/})
DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/})
# PyTorch 2.6+ (AOTriton 0.8b+)
# AKS = "AOTriton Kernel Storage", a file format to store GPU kernels compactly
if (( $(echo "${PYTORCH_VERSION} 2.6" | awk '{print ($1 >= $2)}') )); then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/lib/" -name "libaotriton_v2.so" -printf '%h\n')
if [[ -z ${LIBAOTRITON_DIR} ]]; then
LIBAOTRITON_DIR=$(find "$ROCM_HOME/" -name "libaotriton_v2.so" -printf '%h\n')
fi
AKS_FILES=($(find "${LIBAOTRITON_DIR}/aotriton.images" -type f -name '*.aks?' -printf '%P\n'))
AKS_SRC="${LIBAOTRITON_DIR}/aotriton.images"
AKS_DST="lib/aotriton.images"
DEPS_AUX_SRCLIST+=(${AKS_FILES[@]/#/${AKS_SRC}/})
DEPS_AUX_DSTLIST+=(${AKS_FILES[@]/#/${AKS_DST}/})
fi
echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}"
SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"

View File

@ -1160,6 +1160,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();
using aotriton::v2::flash::attn_fwd;
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
@ -1172,6 +1173,29 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
hipError_t err; // TODO: Error handling
if (seqstart_q.has_value()) {
// varlen aka nested tensor
err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
softmax_scale,
mk_aotensor<2>(softmax_lse, "M"),
mk_aotensor(output_t, "Out"),
dropout_p,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
@ -1188,6 +1212,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
}
if (!compute_logsumexp) {
// Set the tensor to empty when compute_logsumexp is false
logsumexp = at::empty(

View File

@ -441,10 +441,37 @@ _efficient_attention_backward(
hipError_t err;
using aotriton::v2::flash::attn_bwd;
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype()));
if (cu_seqlens_q.has_value()) {
// varlen aka Nested tensor
err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_t, "dq"),
mk_aotensor(dk_t, "dk"),
mk_aotensor(dv_t, "dv"),
bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4,
mk_aotensor<2>(softmax_lse, "L"),
mk_aotensor<2>(delta, "delta"),
float(dropout_p),
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
err = attn_bwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
@ -464,6 +491,7 @@ _efficient_attention_backward(
0,
is_causal,
stream);
}
#else
at::Tensor workspace;
cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());

View File

@ -103,6 +103,10 @@ int64_t minimum_gemm_alignment(sdp_params const& params) {
return matmul_alignment_mn;
}
// On ROCm, ME and FA share the backend, and hence they share the check for
// fundamental limitations imposed by the GPU kernel.
// caller_is_meff is added so the TORCH_WARN message reports the correct backend.
template<bool caller_is_meff = false>
bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
// All head_dim sizes must be equal and less than 256
const auto max_size = c10::SymInt(256);
@ -114,7 +118,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
if (!(same_head_dim_size && (query_size_last <= max_size))) {
if (debug) {
TORCH_WARN(
"Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.",
caller_is_meff ? "Efficient attention on ROCM" : "Flash attention",
" requires q,k,v to have the same last dimension and to be less than or equal to 256.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
@ -128,6 +133,8 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
return true;
}
// See check_head_dim_size_flash above for the purpose of caller_is_meff
template<bool caller_is_meff = false>
bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
const auto max_size = c10::SymInt(256);
const auto query_size_last = params.query.sym_size(-1);
@ -139,7 +146,9 @@ bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) {
(query_size_last <= max_size))) {
if (debug) {
TORCH_WARN(
"For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
"For NestedTensor inputs,",
caller_is_meff ? " Efficient attention on ROCM " : " Flash attention",
" requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
" Got Query.size(-1): ",
query_size_last,
", Key.size(-1): ",
@ -208,7 +217,6 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@ -220,19 +228,11 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
return false;
}
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_navi3x) {
TORCH_WARN_ONCE("Flash attention support on Navi31 GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@ -252,7 +252,6 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
auto dprops = at::cuda::getCurrentDeviceProperties();
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
@ -264,19 +263,11 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
c10::string_view arch(dprops->gcnArchName);
if (arch == "gfx1100") {
static const bool enable_navi3x = c10::utils::check_env("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL") == true;
if (!enable_navi3x) {
TORCH_WARN_ONCE("Memory Efficient attention on Navi31 GPU is still experimental."
" Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1.");
return false;
}
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
if (debug) {
TORCH_WARN(
@ -615,7 +606,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
check_all_tensors_on_device,
check_tensor_shapes,
check_for_attn_mask,
check_head_dim_size_flash,
check_head_dim_size_flash<false /*caller_is_meff*/>,
check_flash_attention_hardware_support,
check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
check_flash_causal_non_square_seqlens,
@ -629,7 +620,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
if (has_for_nested_inputs(params)) {
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_nested,
check_head_dim_size_flash_nested,
check_head_dim_size_flash_nested<false /*caller_is_meff*/>,
check_for_seq_len_0_nested_tensor);
for (auto& constraint : nested_constraints) {
if (!constraint(params, debug)) {
@ -637,11 +628,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
}
}
#if USE_ROCM
constexpr bool backend_supports_grouped_query_attention = false;
#else
constexpr bool backend_supports_grouped_query_attention = true;
#endif
if (has_only_dense_inputs(params)) {
constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
check_batch_size_and_num_heads_dense<backend_supports_grouped_query_attention>,
@ -679,7 +666,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
check_mem_efficient_hardware_support,
check_tensor_shapes,
#ifdef USE_ROCM
check_head_dim_size_flash
check_head_dim_size_flash<true /* caller_is_meff */>
#else
check_head_dim_size_mem_efficient
#endif
@ -691,12 +678,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
if (has_for_nested_inputs(params)) {
#ifdef USE_ROCM
TORCH_WARN_ONCE(false, "[ROCM] no support for nested tensors in memory efficient attention.");
return false;
#endif
constexpr auto nested_constraints = array_of<bool (*)(sdp_params const&, bool)>(
#ifndef USE_ROCM // ME and FA share the backend on ROCm and thus support training
check_requires_grad_and_nested,
#else // Meanwhile, ME on ROCm shares FA's limits on head dimensions
check_head_dim_size_flash_nested<true /* caller_is_meff */>,
#endif
check_batch_size_nested,
check_for_seq_len_0_nested_tensor);
for (auto& constraint : nested_constraints) {

View File

@ -77,6 +77,37 @@ void check_gpu_arch(hipStream_t stream) {
}
}
// We want to checkpoint and save the RNG state for backward if dropout is enabled.
// We get the default generator and return the seed and offset, which will
// be used in the backward function.
std::tuple<at::Tensor, at::Tensor, at::PhiloxCudaState, bool>
prepare_philox_arguments(float p_dropout, int64_t counter_offset) {
at::Tensor seed_t, offset_t;
at::PhiloxCudaState philox_state;
bool use_philox_state = false;
if (p_dropout <= 0.0) {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
return { seed_t, offset_t, philox_state, use_philox_state };
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
} else {
// See Note [CUDA Graph-safe RNG states] about the design
use_philox_state = true;
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
return { seed_t, offset_t, philox_state, use_philox_state };
}
}
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
@ -158,44 +189,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = round_multiple(head_size_og, 8);
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;
at::PhiloxCudaState philox_state;
bool use_philox_state = false;
if (p_dropout > 0.0) {
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = batch_size * num_heads * 32;
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
philox_state = gen->philox_cuda_state(counter_offset);
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong).device(at::kCUDA));
} else {
// See Note [CUDA Graph-safe RNG states] about the design
use_philox_state = true;
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
} else {
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
} else {
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
}
}
auto [seed_t, offset_t, philox_state, use_philox_state] =
prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
// Transpose tensors to meet AOTriton's Flash API
at::Tensor q_t = q_padded.permute({0,2,1,3});
@ -215,7 +211,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd;
using aotriton::TensorView;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
@ -266,16 +261,150 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
int window_size_right,
const bool return_softmax,
std::optional<at::Generator> gen_) {
TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
const bool paged_KV = block_table_.has_value();
TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt");
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
TORCH_CHECK(false, "mha_varlen_fwd not supported on ROCm");
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
at::Tensor softmax_lse = at::empty({}, at::dtype(at::kFloat));
at::Tensor p = at::empty({}, at::dtype(at::kFloat));
at::Tensor offset_t = at::empty({}, at::dtype(at::kLong));
at::Tensor seed_t = at::empty({}, at::dtype(at::kLong));
at::Tensor out = at::empty({}, at::dtype(at::kFloat));
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
return {out, q, k, v, softmax_lse, seed_t, offset_t, p};
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(cu_seqlens_q);
CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) {
is_causal = false;
} // causal=true is the same as causal=false in this case
at::Tensor temp_q = q;
const int total_q = temp_q.sizes()[0];
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) {
window_size_left = -1;
}
if (window_size_right >= max_seqlen_k) {
window_size_right = -1;
}
CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
// AOTriton's varlen API needs input shapes to be
// (1, num_heads, total sequence length, head dimension)
at::Tensor q_padded, k_padded, v_padded;
at::Tensor out, out_padded;
q_padded = q.unsqueeze(0).transpose(1, 2);
k_padded = k.unsqueeze(0).transpose(1, 2);
v_padded = v.unsqueeze(0).transpose(1, 2);
if (out_.has_value()) {
out = out_.value();
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
} else {
out = at::empty_like(q);
}
out_padded = out.unsqueeze(0).transpose(1, 2);
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size = head_size_og;
auto opts = q.options();
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q});
at::Tensor softmax_fa_t;
// Only return softmax if there's dropout to reduce compilation time
if (return_softmax) {
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
softmax_fa_t = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts);
} else {
softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts);
}
if (zero_tensors) {
out.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_softmax) {
softmax_fa_t.zero_();
}
}
auto [seed_t, offset_t, philox_state, use_philox_state] =
prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
if (max_seqlen_k > 0) {
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
empty_bias,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(out_padded, "Out"),
p_dropout,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}
return {out, q, k, v, softmax_lse, seed_t, offset_t, softmax_fa_t};
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -297,6 +426,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
@ -341,8 +473,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
@ -381,23 +511,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv = at::empty_like(k);
}
// const at::Tensor& dout_padded = dout;
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
at::Tensor dk_expanded, dv_expanded;
if (num_heads_k != num_heads) { // MQA / GQA
dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
} else {
dk_expanded = dk;
dv_expanded = dv;
}
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor q_t = q.permute({0,2,1,3});
at::Tensor k_t = k.permute({0,2,1,3});
@ -440,14 +555,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
stream);
}
// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
}
return { dq, dk, dv, softmax_d };
#undef CALL_BWD_DROPOUT
#undef CALL_BWD
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -473,13 +581,173 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
int window_size_right,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");
const at::Tensor philox_offset)
{
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_bwd: alibi_slopes_ must be nullopt");
at::Tensor softmax_d = at::empty({}, at::dtype(at::kFloat));
return { q, k, v, softmax_d };
if (is_causal) {
window_size_right = 0;
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
bool is_dropout = p_dropout > 0.0;
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);
const auto sizes = q.sizes();
const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
const int head_size_og = dout.size(2);
const int head_size = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
CHECK_SHAPE(out, total_q, num_heads, head_size);
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous();
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
at::Tensor q_padded, k_padded, v_padded;
q_padded = q.unsqueeze(0).transpose(1, 2);
k_padded = k.unsqueeze(0).transpose(1, 2);
v_padded = v.unsqueeze(0).transpose(1, 2);
at::Tensor out_t, dout_t;
out_t = out.unsqueeze(0).transpose(1, 2);
dout_t = dout.unsqueeze(0).transpose(1, 2);
at::Tensor dq, dk, dv;
at::Tensor dq_padded, dk_padded, dv_padded;
if (dq_.has_value()) {
dq = dq_.value();
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
CHECK_DEVICE(dq);
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
CHECK_SHAPE(dq, total_q, num_heads, head_size);
} else {
dq = at::empty_like(q);
}
if (dk_.has_value()) {
dk = dk_.value();
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
CHECK_DEVICE(dk);
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
} else {
dk = at::empty_like(k);
}
if (dv_.has_value()) {
dv = dv_.value();
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
CHECK_DEVICE(dv);
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
} else {
dv = at::empty_like(v);
}
dq_padded = dq.unsqueeze(0).transpose(1, 2);
dk_padded = dk.unsqueeze(0).transpose(1, 2);
dv_padded = dv.unsqueeze(0).transpose(1, 2);
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
if( zero_tensors ) {
dq.zero_();
dk.zero_();
dv.zero_();
softmax_d.zero_();
}
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (at::cuda::currentStreamCaptureStatus() ==
at::cuda::CaptureStatus::None)
{
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
} else { // dropout + capture
philox_args = at::PhiloxCudaState(
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
}
}
if (max_seqlen_q > 0) {
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
empty_bias,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_padded, "dq"),
mk_aotensor(dk_padded, "dk"),
mk_aotensor(dv_padded, "dv"),
empty_bias,
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dq.zero_();
dk.zero_();
dv.zero_();
softmax_d.zero_();
}
return { dq, dk, dv, softmax_d };
}
} // namespace pytorch_fmha
#endif

View File

@ -1373,6 +1373,13 @@ def main():
"lib/*.lib",
]
)
aotriton_image_path = os.path.join(lib_path, "aotriton.images")
aks2_files = []
for root, dirs, files in os.walk(aotriton_image_path):
subpath = os.path.relpath(root, start=aotriton_image_path)
for fn in files:
aks2_files.append(os.path.join("lib/aotriton.images", subpath, fn))
torch_package_data += aks2_files
if get_cmake_cache_vars()["USE_TENSORPIPE"]:
torch_package_data.extend(
[

View File

@ -389,8 +389,6 @@ class TestFlexAttention(InductorTestCase):
KV_S = Q_S
if V_D is None:
V_D = Q_D
if TEST_WITH_ROCM and Q_H != KV_H:
self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
q = torch.randn(
(Q_B, Q_H, Q_S, Q_D), dtype=dtype, device="cuda", requires_grad=True
)
@ -565,9 +563,6 @@ class TestFlexAttention(InductorTestCase):
V_D: int = D,
block_mask: Optional[BlockMask] = None,
):
if TEST_WITH_ROCM and Q_H != KV_H:
self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
assert Q_H % KV_H == 0
q = torch.randn(

View File

@ -22,7 +22,7 @@ from torch.nn.attention.flex_attention import (
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16
from torch.testing._internal.common_utils import skipIfRocm, TEST_WITH_ROCM
from torch.testing._internal.common_utils import skipIfRocm
from torch.utils._triton import has_triton
@ -492,9 +492,6 @@ class TestFlexDecoding(InductorTestCase):
V_D: int = D,
block_mask: Optional[BlockMask] = None,
):
if TEST_WITH_ROCM and Q_H != KV_H:
self.skipTest("enable_gqa=True is unsupported on ROCM, for now")
assert Q_H % KV_H == 0
q = torch.randn(

View File

@ -277,8 +277,8 @@ class TestMHADeviceType(TestCase):
def test_native_multihead_self_attention(self, device, dtype, use_nt,
need_weights, average_attn_weights, use_padding, pad_all, fused):
if TEST_WITH_ROCM:
if use_nt:
self.skipTest("ROCM does not support nested tensors for Flash Attention for now.")
if use_nt and use_padding and pad_all:
self.skipTest("Large numerical errors on ROCM to investigate.")
if use_padding and not pad_all and fused:
self.skipTest("Large numerical errors on ROCM to investigate.")
for need_weights in (False, not pad_all):

View File

@ -1630,7 +1630,6 @@ class TestSDPAFailureModes(NNTestCase):
q, k, v, None, 0.0, False))
@onlyCUDA
@skipIfRocm(msg='enable_gqa=True unsupported')
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("fused_kernel", [SDPBackend.EFFICIENT_ATTENTION])
def test_invalid_sdpa_kernel_grouped_query_attention_cuda(self, device, fused_kernel):
@ -1646,7 +1645,6 @@ class TestSDPAFailureModes(NNTestCase):
is_causal=False, enable_gqa=True)
@onlyCPU
@skipIfRocm(msg='enable_gqa=True unsupported')
def test_invalid_sdpa_kernel_grouped_query_attention_cpu(self, device):
rand_query = torch.rand(8, 8, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
rand_key = torch.rand(8, 4, 64, 64, device=device, dtype=torch.float16, requires_grad=True)
@ -1784,7 +1782,6 @@ class TestSDPAFailureModes(NNTestCase):
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
@onlyCUDA
@skipIfRocm # Missing EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
def test_fused_kernels_nested_broadcasting_error_cases(self, device):
# one of k,v needs to be broadcasted and other has non consistent seq_len dim
@ -2485,7 +2482,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
@skipIfRocm # No cuDNN Attention
@skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_fused_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
@ -2526,7 +2523,7 @@ class TestSDPACudaOnly(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
@skipIfRocm # No cuDNN Attention
@skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_trivial_output_transpose(self, device):
# see also: https://github.com/pytorch/pytorch/issues/134001
@ -2713,8 +2710,6 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("type", ["dense", "nested"])
@parametrize("is_contiguous", [True, False])
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
if TEST_WITH_ROCM and type == 'nested':
self.skipTest("ROCM does not support efficient attention on nested tensors, for now")
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
@ -2743,7 +2738,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
@skipIfRocm # Missing nested and EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("type", ["dense", "nested"])
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
@ -2910,7 +2904,6 @@ class TestSDPACudaOnly(NNTestCase):
atol = 9e-4 if dtype == torch.float16 else 9e-3
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
@skipIfRocm # Missing nested and EFFICIENT_ATTENTION
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
@parametrize("type", ["dense", "nested"])
def test_fused_sdp_choice(self, device, type: str):
@ -3116,8 +3109,6 @@ class TestSDPACudaOnly(NNTestCase):
return
if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
torch.cuda.empty_cache() # Prevent memory fragmentation
if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")
seed = 42
scale = scale if scale is None else (1 / head_dim)
n_heads = 4
@ -3227,9 +3218,6 @@ class TestSDPACudaOnly(NNTestCase):
if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30:
unittest.skip("Reference implementation OOM")
return
if TEST_WITH_ROCM and dtype == torch.float32:
unittest.skip("Skip fp32 attn_mask gradients on ROCM, for now.")
return
if TEST_WITH_ROCM and seq_len_q * seq_len_k * head_dim * batch_size > 1024 * 1024 * 128:
torch.cuda.empty_cache() # Prevent memory fragmentation
seed = 42
@ -3323,7 +3311,7 @@ class TestSDPACudaOnly(NNTestCase):
@parametrize("dropout_p", [0.0, 0.22, 0.48])
@parametrize("dtype", [torch.float16, torch.bfloat16])
@parametrize("scale", [None, "l1"])
@parametrize("enable_gqa", [True, False] if not TEST_WITH_ROCM else [False])
@parametrize("enable_gqa", [True, False])
@parametrize("n_heads", [[16, 8], [10, 2]])
@tf32_enabled()
def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
@ -3488,8 +3476,6 @@ class TestSDPACudaOnly(NNTestCase):
if fused_kernel == SDPBackend.FLASH_ATTENTION and is_causal and seq_len_q != seq_len_k:
self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
if TEST_WITH_ROCM and is_causal and seq_len_q != seq_len_k:
self.skipTest("ROCm does not accept is_casual when seq_len_q != seq_len_k")
seed = 42
n_heads = 4
@ -3593,7 +3579,6 @@ class TestSDPACudaOnly(NNTestCase):
}
)
@skipIfRocm # Nested Tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@ -3627,7 +3612,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
@skipIfRocm # Nested tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
@ -3653,6 +3637,9 @@ class TestSDPACudaOnly(NNTestCase):
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
batch, num_heads, head_dim = 32, 8, 64
head_dim_v = 32 if is_efficient else head_dim
if TEST_WITH_ROCM and head_dim != head_dim_v:
self.skipTest("head_dim != head_dim_v unsupported on ROCm for now")
return
seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item()
if expand_q_batch
else torch.randint(low=1, high=32, size=(batch,)).tolist())
@ -3716,7 +3703,7 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2)
@skipIfRocm # Nested tensor
@skipIfRocm(msg="Efficient Attention on ROCM does not support head_dim != head_dim_v for now.")
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
def test_fused_kernels_nested_broadcasting_query_dense(self, device):
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
@ -3751,7 +3738,6 @@ class TestSDPACudaOnly(NNTestCase):
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=1e-3, rtol=1e-2)
@skipIfRocm # Nested tensor
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
@parametrize("batch_size", [8, 32])
@parametrize("max_seq_len_q", [32, 256])
@ -3919,7 +3905,6 @@ class TestAttnBias(NNTestCase):
torch.testing.assert_close(key.grad, key_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
torch.testing.assert_close(value.grad, value_prototype.grad, rtol=grad_tolerances.rtol, atol=grad_tolerances.atol)
@skipIfRocm # No support for the second variant for now
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize(
"shape",
@ -3929,6 +3914,9 @@ class TestAttnBias(NNTestCase):
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
)
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
@ -3952,7 +3940,6 @@ class TestAttnBias(NNTestCase):
SDPBackend.CUDNN_ATTENTION]):
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=None)
@skipIfRocm # CausalVariant
@parametrize("causal_variant", [CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT])
@parametrize(
"shape",
@ -3961,6 +3948,10 @@ class TestAttnBias(NNTestCase):
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows")
@skipIfTorchDynamo("This function already calls torch.compile.")
def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]):
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
cnts = CompileCounterWithBackend("aot_eager")
make_tensor = partial(
torch.rand, device=device, dtype=torch.float16, requires_grad=True
@ -3988,7 +3979,6 @@ class TestAttnBias(NNTestCase):
self.run_test(device, make_q_tensor, make_kv_tensor, attn_bias, forw_tol, grad_tol, backend=cnts)
self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
@skipIfRocm
@parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)])
def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]):
make_tensor = partial(

View File

@ -8762,10 +8762,7 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
qkv_shapes = [(dim_3_q_shape, dim_3_kv_shape), (dim_4_q_shape, dim_4_kv_shape), broadcast_tuple]
samples = []
gqa_options = [False] if TEST_WITH_ROCM else [True, False] # TODO: GQA support
if TEST_WITH_ROCM and dtype == torch.float32:
causal_options = [False] # FIXME: Large errors with causal+fp32
else:
gqa_options = [True, False]
causal_options = [True, False]
for qkv_shape, is_causal, dropout_p, _enable_gqa in product(
qkv_shapes, causal_options, [0.0, 0.5], gqa_options):
@ -8799,16 +8796,6 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_
dropout_p=0.0)
)
if not TEST_WITH_ROCM:
samples.append(
SampleInput(
make((batch, num_heads_q_gqa, seq_q, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
make((batch, num_heads_kv_gqa, seq_kv, head_dim)),
enable_gqa=True
)
)
yield from samples