[ROCm] Bump AOTriton to 0.10b (#156290)

Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.10b:

* Official support of gfx950/gfx1201
* Experimental support of gfx1101/gfx1151/gfx1150/gfx1200
* Reduce libaotriton.so binary size by over 80%.
  + Without this optimization the binary size of `libaotriton.so` could be
    over 100MiB due to 2x more supported architectures compared with 0.9b.
    Now it is only about 11MiB.
* Support sliding window attention (SWA) in
  `_flash_attention_forward/backward`. Should fix #154582

See https://github.com/ROCm/aotriton/releases/tag/0.10b for full details,
including Known Problems.

Notable changes to SDPA backend:

* `std::optional<int64_t>` `window_size_left/right` are directly passed to
  ROCM's SDPA backend, because the default value `-1` is meaningful to
  AOTriton's backend and bottom-right aligned causal mask is implemented with
  negative `window_size_left/right`
* Some code clean up around `USE_CK_FLASH_ATTENTION`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156290
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
This commit is contained in:
Xinya Zhang
2025-06-19 21:13:53 +00:00
committed by PyTorch MergeBot
parent 3644b41a7c
commit 34d8e64ef6
7 changed files with 368 additions and 241 deletions

View File

@ -1113,8 +1113,10 @@ _flash_attention_forward(
std::optional<Tensor> alibi_slopes = _alibi_slopes;
const float softcap = 0.0;
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
#endif
// We are going to have two paths:
// 1. The standard MHA path for dense tensors
@ -1151,8 +1153,13 @@ _flash_attention_forward(
softmax_scale,
false /*zero_tensors*/,
is_causal,
#ifdef USE_ROCM
window_size_left,
window_size_right,
#else
non_null_window_left,
non_null_window_right,
#endif
softcap,
return_debug_mask,
std::nullopt /*gen_*/);
@ -1175,8 +1182,13 @@ _flash_attention_forward(
dropout_p,
softmax_scale,
is_causal,
#ifdef USE_ROCM
window_size_left,
window_size_right,
#else
non_null_window_left,
non_null_window_right,
#endif
softcap,
return_debug_mask, /*return_softmax (this is used for testing)*/
std::nullopt);

View File

@ -87,8 +87,10 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
auto contiguous_grad_out = grad_out.contiguous();
auto contiguous_out = out.contiguous();
#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
#endif
std::optional<at::Tensor> dq{std::nullopt};
std::optional<at::Tensor> dk{std::nullopt};
@ -136,8 +138,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
softmax_scale,
false /*zero_tensors*/,
is_causal,
#ifdef USE_ROCM
window_size_left,
window_size_right,
#else
non_null_window_left,
non_null_window_right,
#endif
softcap,
determinisitic,
philox_seed,
@ -159,8 +166,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dropout_p,
softmax_scale,
is_causal,
#ifdef USE_ROCM
window_size_left,
window_size_right,
#else
non_null_window_left,
non_null_window_right,
#endif
softcap,
determinisitic,
philox_seed,

View File

@ -64,8 +64,14 @@
#include <aotriton/flash.h>
#include <aotriton/runtime.h>
#if AOTRITON_VERSION_MINOR != 9
#error "This adaptor code is only tested with AOTriton 0.9.x"
#if AOTRITON_VERSION_MINOR < 9
#error "This adaptor code is only tested with AOTriton >= 0.9"
#endif
#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
#define V3_API 1
#else
#define V3_API 0
#endif
namespace pytorch_flash {
@ -81,6 +87,38 @@ void check_gpu_arch(hipStream_t stream) {
}
}
std::tuple<bool, int, int>
calculate_swa(std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
int max_seqlen_q,
int max_seqlen_k,
bool is_causal) {
#if V3_API // SWA is exposed through V3 API
bool needs_swa = false;
using aotriton::v3::flash::WindowValue;
// Default values when std::optional window_size_left/right have no value
int window_left = max_seqlen_q;
int window_right = max_seqlen_k;
if (is_causal) {
window_left = WindowValue::TopLeftAligned;
window_right = WindowValue::TopLeftAligned;
}
if (window_size_left.has_value() || window_size_right.has_value()) {
needs_swa = true;
window_left = window_size_left.value_or(window_left);
window_right = window_size_right.value_or(window_right);
}
return std::make_tuple(needs_swa, window_left, window_right);
#else
if (window_size_left.has_value() || window_size_right.has_value()) {
TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)."
" Both window_size_left and window_size_right will be ignored."
" Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support.");
}
return std::make_tuple(false, 0, 0);
#endif
}
// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
@ -127,8 +165,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
@ -161,7 +199,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
@ -212,6 +249,19 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
}
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
window_size_right,
seqlen_q,
seqlen_k,
is_causal);
#if V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_fwd;
using sdp::aotriton_adapter::mk_aotensor;
@ -226,23 +276,54 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
empty_bias,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
if (uses_swa) {
#if V3_API
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_fwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.Sm_scale = softmax_scale;
params.L = mk_aotensor<2>(M, "M");
params.Out = mk_aotensor(output_t, "Out");
params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = p_dropout;
params.philox_seed_ptr = seed;
params.philox_offset1 = offset1;
params.philox_offset2 = offset2;
params.philox_seed_output = seed_output;
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = CausalType::WindowedAttention;
params.varlen_type = VarlenType::None;
params.window_left = window_left;
params.window_right = window_right;
err = aotriton::v3::flash::attn_fwd(params,
aotriton::v3::flash::attn_fwd_params::kVersion,
stream);
#endif
} else {
err = attn_fwd(mk_aotensor(q_t, "q"),
mk_aotensor(k_t, "k"),
mk_aotensor(v_t, "v"),
empty_bias,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(output_t, "Out"),
p_dropout,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t};
}
@ -263,8 +344,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_) {
TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
@ -312,13 +393,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) {
window_size_left = -1;
}
if (window_size_right >= max_seqlen_k) {
window_size_right = -1;
}
CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
const int total_k = k.size(0);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
@ -368,6 +442,19 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
}
}
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
window_size_right,
max_seqlen_q,
max_seqlen_k,
is_causal);
#if V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
auto [seed_t, offset_t, philox_state, use_philox_state] =
prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
@ -390,27 +477,58 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
empty_bias,
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(out_padded, "Out"),
p_dropout,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
if (uses_swa) {
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_fwd_params params;
params.Q = mk_aotensor(q_padded, "q");
params.K = mk_aotensor(k_padded, "k");
params.V = mk_aotensor(v_padded, "v");
params.Sm_scale = softmax_scale;
params.L = mk_aotensor<2>(M, "M");
params.Out = mk_aotensor(out_padded, "Out");
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = p_dropout;
params.philox_seed_ptr = seed;
params.philox_offset1 = offset1;
params.philox_offset2 = offset2;
params.philox_seed_output = seed_output;
params.philox_offset_output = offset_output;
params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
params.persistent_atomic_counter = persistent_counter;
params.causal_type = CausalType::WindowedAttention;
params.varlen_type = VarlenType::CompactVarlen;
params.window_left = window_left;
params.window_right = window_right;
err = aotriton::v3::flash::attn_fwd(params,
aotriton::v3::flash::attn_fwd_params::kVersion,
stream);
} else {
err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
empty_bias,
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
softmax_scale,
mk_aotensor<2>(M, "M"),
mk_aotensor(out_padded, "Out"),
p_dropout,
seed,
offset1,
offset2,
seed_output,
offset_output,
mk_aotensor(softmax_fa_t, "encoded_softmax"),
is_causal,
persistent_counter,
stream);
}
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
@ -434,8 +552,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool deterministic,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset) {
@ -524,6 +642,19 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
dv = at::empty_like(k);
}
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
window_size_right,
seqlen_q,
seqlen_k,
is_causal);
#if V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
auto opts = q.options();
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@ -541,10 +672,40 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
int d_head = head_size_og;
bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512;
hipError_t err; // TODO: Error handling
if (use_fused_bwd) {
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
if (uses_swa) {
// Fused BWD does not support SWA
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_bwd_params params;
params.Q = mk_aotensor(q_t, "q");
params.K = mk_aotensor(k_t, "k");
params.V = mk_aotensor(v_t, "v");
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dq_t, "dq");
params.DV = mk_aotensor(dk_t, "dk");
params.DQ = mk_aotensor(dv_t, "dv");
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
params.D = mk_aotensor<2>(delta, "delta");
params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = p_dropout;
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = CausalType::WindowedAttention;
params.varlen_type = VarlenType::None;
params.window_left = window_left;
params.window_right = window_right;
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
} else if (use_fused_bwd) {
using aotriton::v2::flash::attn_bwd_fused;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd_fused(mk_aotensor(q_t, "q"),
@ -568,8 +729,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
} else {
at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
using aotriton::v2::flash::attn_bwd;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd(mk_aotensor(q_t, "q"),
@ -615,17 +774,14 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool deterministic,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset)
{
TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
if (is_causal) {
window_size_right = 0;
}
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
@ -669,9 +825,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@ -734,6 +887,19 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
softmax_d.zero_();
}
auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
window_size_right,
max_seqlen_q,
max_seqlen_k,
is_causal);
#if V3_API
const bool uses_swa = needs_swa;
#else
// When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
// optimized out (hopefully).
constexpr bool uses_swa = false;
#endif
at::PhiloxCudaState philox_args;
if (is_dropout) {
if (at::cuda::currentStreamCaptureStatus() ==
@ -747,34 +913,66 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
}
if (max_seqlen_q > 0) {
hipError_t err; // TODO: Error handling
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
empty_bias,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_padded, "dq"),
mk_aotensor(dk_padded, "dk"),
mk_aotensor(dv_padded, "dv"),
empty_bias,
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
if (uses_swa) {
using aotriton::v3::flash::CausalType;
using aotriton::v3::flash::VarlenType;
aotriton::v3::flash::attn_bwd_params params;
params.Q = mk_aotensor(q_padded, "q");
params.K = mk_aotensor(k_padded, "k");
params.V = mk_aotensor(v_padded, "v");
params.Sm_scale = softmax_scale;
params.Out = mk_aotensor(out_t, "out");
params.DO = mk_aotensor(dout_t, "dout");
params.DK = mk_aotensor(dq_padded, "dq");
params.DV = mk_aotensor(dk_padded, "dk");
params.DQ = mk_aotensor(dv_padded, "dv");
params.L = mk_aotensor<2>(softmax_lse_cont, "L");
params.D = mk_aotensor<2>(delta, "delta");
params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
params.dropout_p = p_dropout;
params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
params.philox_offset1 = mk_aoscalartensor(philox_offset);
params.philox_offset2 = 0;
params.causal_type = CausalType::WindowedAttention;
params.varlen_type = VarlenType::CompactVarlen;
params.window_left = window_left;
params.window_right = window_right;
err = aotriton::v3::flash::attn_bwd(params,
aotriton::v3::flash::attn_bwd_params::kVersion,
stream);
} else {
using aotriton::v2::flash::attn_bwd_compact_varlen;
using sdp::aotriton_adapter::cast_dtype;
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
mk_aotensor(k_padded, "k"),
mk_aotensor(v_padded, "v"),
mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
max_seqlen_q,
max_seqlen_k,
empty_bias,
softmax_scale,
mk_aotensor(out_t, "out"),
mk_aotensor(dout_t, "dout"),
mk_aotensor(dq_padded, "dq"),
mk_aotensor(dk_padded, "dk"),
mk_aotensor(dv_padded, "dv"),
empty_bias,
mk_aotensor<2>(softmax_lse_cont, "L"),
mk_aotensor<2>(delta, "delta"),
p_dropout,
mk_aoscalartensor(philox_seed),
mk_aoscalartensor(philox_offset),
0,
is_causal,
stream);
}
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dq.zero_();

View File

@ -51,8 +51,8 @@ mha_fwd_aot(
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_);
@ -87,8 +87,8 @@ mha_varlen_fwd_aot(
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool return_softmax,
const std::optional<at::Generator>& gen_);
@ -110,8 +110,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_aot(
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool deterministic,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset);
@ -141,8 +141,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const bool deterministic,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset);
@ -290,14 +290,16 @@ mha_fwd(
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
@ -307,27 +309,13 @@ mha_fwd(
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
#else
#endif
return mha_fwd_aot(
q,
k,
@ -341,7 +329,6 @@ mha_fwd(
window_size_right,
return_softmax,
gen_);
#endif
}
inline std::tuple<
@ -376,8 +363,8 @@ mha_varlen_fwd(
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
@ -385,6 +372,8 @@ mha_varlen_fwd(
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
return mha_varlen_fwd_ck(
q,
k,
@ -399,34 +388,13 @@ mha_varlen_fwd(
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
} else {
return mha_varlen_fwd_aot(
q,
k,
v,
out_,
cu_seqlens_q,
cu_seqlens_k,
seqused_k,
block_table_,
alibi_slopes_,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
#else
#endif
return mha_varlen_fwd_aot(
q,
k,
@ -447,7 +415,6 @@ mha_varlen_fwd(
window_size_right,
return_softmax,
gen_);
#endif
}
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
@ -468,16 +435,18 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
#if defined(USE_CK_FLASH_ATTENTION)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
#if defined(USE_CK_FLASH_ATTENTION)
std::optional<at::Tensor> non_null_dbias = std::nullopt;
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
auto[dQuery,
dKey,
dValue,
@ -498,38 +467,16 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
// for FA return [dQ, dV, dK, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
return mha_bwd_aot(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
}
#else
if(at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
#endif
}
return mha_bwd_aot(
dout,
@ -550,7 +497,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
deterministic,
philox_seed,
philox_offset);
#endif
}
inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd(
@ -578,8 +524,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
int window_size_left,
int window_size_right,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool deterministic,
const at::Tensor philox_seed,
@ -588,6 +534,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> non_null_dbias = std::nullopt;
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
auto[dQuery,
dKey,
dValue,
@ -613,40 +561,15 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
non_null_window_left,
non_null_window_right,
deterministic,
philox_seed,
philox_offset);
// for FA return [dQ, dV, dK, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
} else {
return mha_varlen_bwd_aot(
dout,
q,
k,
v,
out,
softmax_lse,
dq_,
dk_,
dv_,
cu_seqlens_q,
cu_seqlens_k,
alibi_slopes_,
max_seqlen_q,
max_seqlen_k,
p_dropout,
softmax_scale,
zero_tensors,
is_causal,
window_size_left,
window_size_right,
deterministic,
philox_seed,
philox_offset);
}
#else
#endif
return mha_varlen_bwd_aot(
dout,
q,
@ -671,7 +594,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
deterministic,
philox_seed,
philox_offset);
#endif
}
} // namespace pytorch_flash

View File

@ -1,16 +1,3 @@
macro(get_target_gpus_from_pytorch target_gpus)
set(gfx90a_key MI200)
set(gfx942_key MI300X)
set(gfx1100_key Navi31)
foreach(X IN LISTS PYTORCH_ROCM_ARCH)
set(key ${X})
string(APPEND key "_key")
string(APPEND target_gpus ${${key}})
string(APPEND target_gpus "|")
endforeach()
endmacro()
if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INCLUDED TRUE)
@ -22,22 +9,22 @@ if(NOT __AOTRITON_INCLUDED)
# Replaces .ci/docker/aotriton_version.txt
# Note packages information may have versions skipped (due to no ABI breaks)
# But they must be listed from lower version to higher version
set(__AOTRITON_VER "0.9.2b")
set(__AOTRITON_VER "0.10b")
set(__AOTRITON_MANYLINUX_LIST
"manylinux_2_28" # rocm6.2
"manylinux_2_28" # rocm6.3
"manylinux_2_28" # rocm6.4
"manylinux_2_28" # rocm7.0
)
set(__AOTRITON_ROCM_LIST
"rocm6.2"
"rocm6.3"
"rocm6.4"
"rocm7.0"
)
set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525")
set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
set(__AOTRITON_SHA256_LIST
"08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2
"9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3
"41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4
"861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
"acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
"7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0
)
set(__AOTRITON_Z "gz")
@ -50,17 +37,13 @@ if(NOT __AOTRITON_INCLUDED)
set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
set(target_gpus "")
get_target_gpus_from_pytorch(target_gpus)
ExternalProject_Add(aotriton_external
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_TAG ${__AOTRITON_CI_COMMIT}
PREFIX ${__AOTRITON_EXTERN_PREFIX}
INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
LIST_SEPARATOR |
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-DTARGET_GPUS:STRING=${target_gpus}
-DAOTRITON_COMPRESS_KERNEL=ON
-DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NO_SHARED=OFF

View File

@ -3294,7 +3294,7 @@ class TestSDPACudaOnly(NNTestCase):
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 160.0
fudge_factors['grad_query'] = 650.0
fudge_factors['grad_query'] = 670.0
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0
@ -3415,7 +3415,7 @@ class TestSDPACudaOnly(NNTestCase):
fudge_factors['grad_key'] = 70.0
if seq_len_k >= 2048:
fudge_factors['grad_key'] = 160.0
fudge_factors['grad_query'] = 650.0
fudge_factors['grad_query'] = 670.0 # gfx90a
if dtype == torch.float32:
fudge_factors['grad_key'] = 90.0

View File

@ -57,9 +57,9 @@ def CDNA2OrLater():
def evaluate_platform_supports_flash_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100"]
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
arch_list += ["gfx1201", "gfx950"]
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
return evaluate_gfx_arch_within(arch_list)
if TEST_CUDA:
return not IS_WINDOWS and SM80OrLater
@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention():
def evaluate_platform_supports_efficient_attention():
if TEST_WITH_ROCM:
arch_list = ["gfx90a", "gfx942", "gfx1100"]
arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
arch_list += ["gfx1201", "gfx950"]
arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
return evaluate_gfx_arch_within(arch_list)
if TEST_CUDA:
return True