Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCm] Bump AOTriton to 0.10b (#156290)
Notable new features/optimizations for SDPA operators on AMD systems from AOTriton 0.10b:

* Official support of gfx950/gfx1201
* Experimental support of gfx1101/gfx1151/gfx1150/gfx1200
* Reduce `libaotriton.so` binary size by over 80%.
  + Without this optimization the binary size of `libaotriton.so` could be over 100MiB, due to 2x more supported architectures compared with 0.9b. Now it is only about 11MiB.
* Support sliding window attention (SWA) in `_flash_attention_forward/backward`. Should fix #154582

See https://github.com/ROCm/aotriton/releases/tag/0.10b for full details, including Known Problems.

Notable changes to the SDPA backend:

* `std::optional<int64_t>` `window_size_left/right` are passed directly to ROCm's SDPA backend, because the default value `-1` is meaningful to AOTriton's backend, and the bottom-right aligned causal mask is implemented with negative `window_size_left/right`.
* Some code cleanup around `USE_CK_FLASH_ATTENTION`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156290
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
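The core semantic point in the list above is that the backend now distinguishes "no window requested" (`std::nullopt`) from an explicit window value such as `-1`, which AOTriton interprets itself (negative windows encode the bottom-right aligned causal mask). Below is a minimal sketch of that resolution rule, with hypothetical helper and parameter names rather than the literal PyTorch code:

```cpp
#include <cstdint>
#include <optional>
#include <tuple>

// Hypothetical illustration of the new window-size semantics: std::nullopt
// means "caller did not request SWA" and falls back to defaults, while any
// explicit value (including -1 or other negatives) is forwarded untouched
// so AOTriton can interpret it.
std::tuple<bool, int64_t, int64_t> resolve_window(
    std::optional<int64_t> window_left,
    std::optional<int64_t> window_right,
    int64_t default_left,
    int64_t default_right) {
  // SWA engages only when at least one bound was explicitly provided.
  const bool use_swa = window_left.has_value() || window_right.has_value();
  return {use_swa,
          window_left.value_or(default_left),
          window_right.value_or(default_right)};
}
```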
This commit: 34d8e64ef6 (parent 3644b41a7c), committed by PyTorch MergeBot.
```diff
@@ -1113,8 +1113,10 @@ _flash_attention_forward(
   std::optional<Tensor> alibi_slopes = _alibi_slopes;
   const float softcap = 0.0;
 
-  const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
-  const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
+#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
+  const int non_null_window_left = window_size_left.value_or(-1);
+  const int non_null_window_right = window_size_right.value_or(-1);
+#endif
 
   // We are going to have two paths:
   // 1. The standard MHA path for dense tensors
@@ -1151,8 +1153,13 @@ _flash_attention_forward(
               softmax_scale,
               false /*zero_tensors*/,
               is_causal,
+#ifdef USE_ROCM
+              window_size_left,
+              window_size_right,
+#else
               non_null_window_left,
               non_null_window_right,
+#endif
               softcap,
               return_debug_mask,
               std::nullopt /*gen_*/);
@@ -1175,8 +1182,13 @@ _flash_attention_forward(
               dropout_p,
               softmax_scale,
               is_causal,
+#ifdef USE_ROCM
+              window_size_left,
+              window_size_right,
+#else
               non_null_window_left,
               non_null_window_right,
+#endif
               softcap,
               return_debug_mask, /*return_softmax (this is used for testing)*/
               std::nullopt);
```
```diff
@@ -87,8 +87,10 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
   auto contiguous_grad_out = grad_out.contiguous();
   auto contiguous_out = out.contiguous();
 
+#ifndef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly.
   const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
   const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
+#endif
 
   std::optional<at::Tensor> dq{std::nullopt};
   std::optional<at::Tensor> dk{std::nullopt};
@@ -136,8 +138,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
             softmax_scale,
             false /*zero_tensors*/,
             is_causal,
+#ifdef USE_ROCM
+            window_size_left,
+            window_size_right,
+#else
             non_null_window_left,
             non_null_window_right,
+#endif
             softcap,
             determinisitic,
             philox_seed,
@@ -159,8 +166,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
             dropout_p,
             softmax_scale,
             is_causal,
+#ifdef USE_ROCM
+            window_size_left,
+            window_size_right,
+#else
             non_null_window_left,
             non_null_window_right,
+#endif
             softcap,
             determinisitic,
             philox_seed,
```
```diff
@@ -64,8 +64,14 @@
 #include <aotriton/flash.h>
 #include <aotriton/runtime.h>
 
-#if AOTRITON_VERSION_MINOR != 9
-#error "This adaptor code is only tested with AOTriton 0.9.x"
+#if AOTRITON_VERSION_MINOR < 9
+#error "This adaptor code is only tested with AOTriton >= 0.9"
 #endif
 
+#if (AOTRITON_VERSION_MAJOR * 100 + AOTRITON_VERSION_MINOR) >= 10
+#define V3_API 1
+#else
+#define V3_API 0
+#endif
+
 namespace pytorch_flash {
```
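For reference, the new gate linearizes the version pair as `major * 100 + minor`, so 0.9 maps to 9 (no V3 API) and 0.10 maps to 10 (V3 API available). A small compile-time sanity check of that arithmetic (illustrative only; the real values come from the AOTriton headers):

```cpp
// Sanity check of the version arithmetic used by the gate above.
static_assert(0 * 100 + 9  <  10, "AOTriton 0.9 predates the V3 API");
static_assert(0 * 100 + 10 >= 10, "AOTriton 0.10 enables V3_API");
static_assert(1 * 100 + 0  >= 10, "any 1.x release also enables V3_API");
```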
```diff
@@ -81,6 +87,38 @@ void check_gpu_arch(hipStream_t stream) {
   }
 }
 
+std::tuple<bool, int, int>
+calculate_swa(std::optional<int64_t> window_size_left,
+              std::optional<int64_t> window_size_right,
+              int max_seqlen_q,
+              int max_seqlen_k,
+              bool is_causal) {
+#if V3_API // SWA is exposed through V3 API
+  bool needs_swa = false;
+  using aotriton::v3::flash::WindowValue;
+  // Default values when std::optional window_size_left/right have no value
+  int window_left = max_seqlen_q;
+  int window_right = max_seqlen_k;
+  if (is_causal) {
+    window_left = WindowValue::TopLeftAligned;
+    window_right = WindowValue::TopLeftAligned;
+  }
+  if (window_size_left.has_value() || window_size_right.has_value()) {
+    needs_swa = true;
+    window_left = window_size_left.value_or(window_left);
+    window_right = window_size_right.value_or(window_right);
+  }
+  return std::make_tuple(needs_swa, window_left, window_right);
+#else
+  if (window_size_left.has_value() || window_size_right.has_value()) {
+    TORCH_WARN_ONCE("Current AOTriton does not support sliding window attention (SWA)."
+                    " Both window_size_left and window_size_right will be ignored."
+                    " Re-compile PyTorch with AOTriton >= 0.10b to enable SWA support.");
+  }
+  return std::make_tuple(false, 0, 0);
+#endif
+}
+
 // We want to checkpoint and save the RNG state for backward if dropout
 // We get the default generator and return the seed and offset which will
 // be used in the backward function
```
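To make the defaulting rules of `calculate_swa` concrete, here is an illustrative call with assumed values (this usage example is not part of the commit). With V3_API enabled, `WindowValue::TopLeftAligned` is the sentinel substituted for both bounds when `is_causal` is set and no explicit window is given:

```cpp
// Illustrative behavior, assuming V3_API == 1:
//   calculate_swa(nullopt, nullopt, 1024, 1024, /*is_causal=*/false)
//       -> {false, 1024, 1024}            // full attention, v2 path
//   calculate_swa(nullopt, nullopt, 1024, 1024, /*is_causal=*/true)
//       -> {false, TopLeftAligned, TopLeftAligned}
//   calculate_swa(128, std::nullopt, 1024, 1024, /*is_causal=*/false)
//       -> {true, 128, 1024}              // SWA engaged, right bound defaulted
auto [needs_swa, window_left, window_right] =
    calculate_swa(/*window_size_left=*/128,
                  /*window_size_right=*/std::nullopt,
                  /*max_seqlen_q=*/1024,
                  /*max_seqlen_k=*/1024,
                  /*is_causal=*/false);
```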
```diff
@@ -127,8 +165,8 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_) {
   auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
@@ -161,7 +199,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
   if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case
-  if (is_causal) { window_size_right = 0; }
 
   CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
   CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
@@ -212,6 +249,19 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
     atomic_counter = at::zeros({1}, opts.dtype(at::kInt));
   }
 
+  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
+                                                              window_size_right,
+                                                              seqlen_q,
+                                                              seqlen_k,
+                                                              is_causal);
+#if V3_API
+  const bool uses_swa = needs_swa;
+#else
+  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
+  // optimized out (hopefully).
+  constexpr bool uses_swa = false;
+#endif
+
   hipError_t err; // TODO: Error handling
   using aotriton::v2::flash::attn_fwd;
   using sdp::aotriton_adapter::mk_aotensor;
```
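Note the fallback's reliance on the optimizer: a plain `if (uses_swa)` still requires its body to compile even when `uses_swa` is a `constexpr false`, which is why the V3-only body in the next hunk is additionally wrapped in `#if V3_API`. A small standalone sketch of the pattern (hypothetical macro and function names, not the PyTorch code):

```cpp
// Standalone sketch of the compile-time dead-branch pattern.
int dispatch(bool needs_swa) {
#if defined(HAVE_V3_API)
  const bool uses_swa = needs_swa;   // runtime decision
#else
  (void)needs_swa;                   // silence unused-parameter warnings
  constexpr bool uses_swa = false;   // branch below folds to the v2 path
#endif
  if (uses_swa) {
    return 3;  // would invoke the v3 sliding-window kernel
  }
  return 2;    // v2 kernel path
}
```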
```diff
@@ -226,23 +276,54 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x
   auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
   auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
   auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
-  err = attn_fwd(mk_aotensor(q_t, "q"),
-                 mk_aotensor(k_t, "k"),
-                 mk_aotensor(v_t, "v"),
-                 empty_bias,
-                 softmax_scale,
-                 mk_aotensor<2>(M, "M"),
-                 mk_aotensor(output_t, "Out"),
-                 p_dropout,
-                 seed,
-                 offset1,
-                 offset2,
-                 seed_output,
-                 offset_output,
-                 mk_aotensor(softmax_fa_t, "encoded_softmax"),
-                 is_causal,
-                 persistent_counter,
-                 stream);
+  if (uses_swa) {
+#if V3_API
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    aotriton::v3::flash::attn_fwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.Sm_scale = softmax_scale;
+    params.L = mk_aotensor<2>(M, "M");
+    params.Out = mk_aotensor(output_t, "Out");
+    params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = p_dropout;
+    params.philox_seed_ptr = seed;
+    params.philox_offset1 = offset1;
+    params.philox_offset2 = offset2;
+    params.philox_seed_output = seed_output;
+    params.philox_offset_output = offset_output;
+    params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
+    params.persistent_atomic_counter = persistent_counter;
+    params.causal_type = CausalType::WindowedAttention;
+    params.varlen_type = VarlenType::None;
+    params.window_left = window_left;
+    params.window_right = window_right;
+    err = aotriton::v3::flash::attn_fwd(params,
+                                        aotriton::v3::flash::attn_fwd_params::kVersion,
+                                        stream);
+#endif
+  } else {
+    err = attn_fwd(mk_aotensor(q_t, "q"),
+                   mk_aotensor(k_t, "k"),
+                   mk_aotensor(v_t, "v"),
+                   empty_bias,
+                   softmax_scale,
+                   mk_aotensor<2>(M, "M"),
+                   mk_aotensor(output_t, "Out"),
+                   p_dropout,
+                   seed,
+                   offset1,
+                   offset2,
+                   seed_output,
+                   offset_output,
+                   mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                   is_causal,
+                   persistent_counter,
+                   stream);
+  }
 
   return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t};
 }
```
```diff
@@ -263,8 +344,8 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_) {
   TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt");
@@ -312,13 +393,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
   TORCH_CHECK(head_size_og <= 512, "FlashAttention on ROCm forward only supports head dimension at most 512");
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-  if (window_size_left >= max_seqlen_k) {
-    window_size_left = -1;
-  }
-  if (window_size_right >= max_seqlen_k) {
-    window_size_right = -1;
-  }
-
   CHECK_SHAPE(temp_q, total_q, num_heads, head_size_og);
   const int total_k = k.size(0);
   CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
@@ -368,6 +442,19 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
     }
   }
 
+  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
+                                                              window_size_right,
+                                                              max_seqlen_q,
+                                                              max_seqlen_k,
+                                                              is_causal);
+#if V3_API
+  const bool uses_swa = needs_swa;
+#else
+  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
+  // optimized out (hopefully).
+  constexpr bool uses_swa = false;
+#endif
+
   auto [seed_t, offset_t, philox_state, use_philox_state] =
       prepare_philox_arguments(p_dropout, batch_size * num_heads * 32);
```
```diff
@@ -390,27 +477,58 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
     auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : nullscalar;
     auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : nullscalar;
     auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr<int64_t>()) : nullscalar;
-    err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
-                                  mk_aotensor(k_padded, "k"),
-                                  mk_aotensor(v_padded, "v"),
-                                  empty_bias,
-                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
-                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
-                                  max_seqlen_q,
-                                  max_seqlen_k,
-                                  softmax_scale,
-                                  mk_aotensor<2>(M, "M"),
-                                  mk_aotensor(out_padded, "Out"),
-                                  p_dropout,
-                                  seed,
-                                  offset1,
-                                  offset2,
-                                  seed_output,
-                                  offset_output,
-                                  mk_aotensor(softmax_fa_t, "encoded_softmax"),
-                                  is_causal,
-                                  persistent_counter,
-                                  stream);
+    if (uses_swa) {
+      using aotriton::v3::flash::CausalType;
+      using aotriton::v3::flash::VarlenType;
+      aotriton::v3::flash::attn_fwd_params params;
+      params.Q = mk_aotensor(q_padded, "q");
+      params.K = mk_aotensor(k_padded, "k");
+      params.V = mk_aotensor(v_padded, "v");
+      params.Sm_scale = softmax_scale;
+      params.L = mk_aotensor<2>(M, "M");
+      params.Out = mk_aotensor(out_padded, "Out");
+      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
+      params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+      params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+      params.dropout_p = p_dropout;
+      params.philox_seed_ptr = seed;
+      params.philox_offset1 = offset1;
+      params.philox_offset2 = offset2;
+      params.philox_seed_output = seed_output;
+      params.philox_offset_output = offset_output;
+      params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax");
+      params.persistent_atomic_counter = persistent_counter;
+      params.causal_type = CausalType::WindowedAttention;
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.window_left = window_left;
+      params.window_right = window_right;
+      err = aotriton::v3::flash::attn_fwd(params,
+                                          aotriton::v3::flash::attn_fwd_params::kVersion,
+                                          stream);
+    } else {
+      err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                    mk_aotensor(k_padded, "k"),
+                                    mk_aotensor(v_padded, "v"),
+                                    empty_bias,
+                                    mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                    mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                    max_seqlen_q,
+                                    max_seqlen_k,
+                                    softmax_scale,
+                                    mk_aotensor<2>(M, "M"),
+                                    mk_aotensor(out_padded, "Out"),
+                                    p_dropout,
+                                    seed,
+                                    offset1,
+                                    offset2,
+                                    seed_output,
+                                    offset_output,
+                                    mk_aotensor(softmax_fa_t, "encoded_softmax"),
+                                    is_causal,
+                                    persistent_counter,
+                                    stream);
+    }
   } else {
     // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
     out.zero_();
```
```diff
@@ -434,8 +552,8 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset) {
@@ -524,6 +642,19 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
     dv = at::empty_like(k);
   }
 
+  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
+                                                              window_size_right,
+                                                              seqlen_q,
+                                                              seqlen_k,
+                                                              is_causal);
+#if V3_API
+  const bool uses_swa = needs_swa;
+#else
+  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
+  // optimized out (hopefully).
+  constexpr bool uses_swa = false;
+#endif
+
   auto opts = q.options();
   auto softmax_d = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
```
```diff
@@ -541,10 +672,40 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
   int d_head = head_size_og;
   bool use_fused_bwd = d_head <= 192 && d_head * seqlen_q < 64 * 512;
   hipError_t err; // TODO: Error handling
-  if (use_fused_bwd) {
-    using aotriton::v2::flash::attn_bwd_fused;
-    using sdp::aotriton_adapter::mk_aotensor;
-    using sdp::aotriton_adapter::mk_aoscalartensor;
+  using sdp::aotriton_adapter::mk_aotensor;
+  using sdp::aotriton_adapter::mk_aoscalartensor;
+  if (uses_swa) {
+    // Fused BWD does not support SWA
+    at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
+    using aotriton::v3::flash::CausalType;
+    using aotriton::v3::flash::VarlenType;
+    aotriton::v3::flash::attn_bwd_params params;
+    params.Q = mk_aotensor(q_t, "q");
+    params.K = mk_aotensor(k_t, "k");
+    params.V = mk_aotensor(v_t, "v");
+    params.Sm_scale = softmax_scale;
+    params.Out = mk_aotensor(out_t, "out");
+    params.DO = mk_aotensor(dout_t, "dout");
+    params.DK = mk_aotensor(dq_t, "dq");
+    params.DV = mk_aotensor(dk_t, "dk");
+    params.DQ = mk_aotensor(dv_t, "dv");
+    params.L = mk_aotensor<2>(softmax_lse_cont, "L");
+    params.D = mk_aotensor<2>(delta, "delta");
+    params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty
+    params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty
+    params.dropout_p = p_dropout;
+    params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
+    params.philox_offset1 = mk_aoscalartensor(philox_offset);
+    params.philox_offset2 = 0;
+    params.causal_type = CausalType::WindowedAttention;
+    params.varlen_type = VarlenType::None;
+    params.window_left = window_left;
+    params.window_right = window_right;
+    err = aotriton::v3::flash::attn_bwd(params,
+                                        aotriton::v3::flash::attn_bwd_params::kVersion,
+                                        stream);
+  } else if (use_fused_bwd) {
+    using aotriton::v2::flash::attn_bwd_fused;
     using sdp::aotriton_adapter::cast_dtype;
     aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
     err = attn_bwd_fused(mk_aotensor(q_t, "q"),
@@ -568,8 +729,6 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea
   } else {
     at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous();
     using aotriton::v2::flash::attn_bwd;
-    using sdp::aotriton_adapter::mk_aotensor;
-    using sdp::aotriton_adapter::mk_aoscalartensor;
     using sdp::aotriton_adapter::cast_dtype;
     aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
     err = attn_bwd(mk_aotensor(q_t, "q"),
```
```diff
@@ -615,17 +774,14 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
     const float softmax_scale,
     const bool zero_tensors,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset)
 {
   TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt");
 
-  if (is_causal) {
-    window_size_right = 0;
-  }
   // Otherwise the kernel will be launched from cuda:0 device
   // Cast to char to avoid compiler warning about narrowing
   at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
@@ -669,9 +825,6 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
   TORCH_CHECK(head_size <= 512, "FlashAttention on ROCm backward only supports head dimension at most 512");
   TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
-  if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
-  if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
-
   CHECK_SHAPE(q, total_q, num_heads, head_size);
   CHECK_SHAPE(k, total_k, num_heads_k, head_size);
   CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -734,6 +887,19 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
     softmax_d.zero_();
   }
 
+  auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left,
+                                                              window_size_right,
+                                                              max_seqlen_q,
+                                                              max_seqlen_k,
+                                                              is_causal);
+#if V3_API
+  const bool uses_swa = needs_swa;
+#else
+  // When V3_API = 0, uses_swa is constexpr and the if (uses_swa) branch can be
+  // optimized out (hopefully).
+  constexpr bool uses_swa = false;
+#endif
+
   at::PhiloxCudaState philox_args;
   if (is_dropout) {
     if (at::cuda::currentStreamCaptureStatus() ==
```
```diff
@@ -747,34 +913,66 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
   }
   if (max_seqlen_q > 0) {
     hipError_t err; // TODO: Error handling
-    using aotriton::v2::flash::attn_bwd_compact_varlen;
     using sdp::aotriton_adapter::mk_aotensor;
     using sdp::aotriton_adapter::mk_aoscalartensor;
-    using sdp::aotriton_adapter::cast_dtype;
-    aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
-    err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
-                                  mk_aotensor(k_padded, "k"),
-                                  mk_aotensor(v_padded, "v"),
-                                  mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
-                                  mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
-                                  max_seqlen_q,
-                                  max_seqlen_k,
-                                  empty_bias,
-                                  softmax_scale,
-                                  mk_aotensor(out_t, "out"),
-                                  mk_aotensor(dout_t, "dout"),
-                                  mk_aotensor(dq_padded, "dq"),
-                                  mk_aotensor(dk_padded, "dk"),
-                                  mk_aotensor(dv_padded, "dv"),
-                                  empty_bias,
-                                  mk_aotensor<2>(softmax_lse_cont, "L"),
-                                  mk_aotensor<2>(delta, "delta"),
-                                  p_dropout,
-                                  mk_aoscalartensor(philox_seed),
-                                  mk_aoscalartensor(philox_offset),
-                                  0,
-                                  is_causal,
-                                  stream);
+    if (uses_swa) {
+      using aotriton::v3::flash::CausalType;
+      using aotriton::v3::flash::VarlenType;
+      aotriton::v3::flash::attn_bwd_params params;
+      params.Q = mk_aotensor(q_padded, "q");
+      params.K = mk_aotensor(k_padded, "k");
+      params.V = mk_aotensor(v_padded, "v");
+      params.Sm_scale = softmax_scale;
+      params.Out = mk_aotensor(out_t, "out");
+      params.DO = mk_aotensor(dout_t, "dout");
+      params.DK = mk_aotensor(dq_padded, "dq");
+      params.DV = mk_aotensor(dk_padded, "dk");
+      params.DQ = mk_aotensor(dv_padded, "dv");
+      params.L = mk_aotensor<2>(softmax_lse_cont, "L");
+      params.D = mk_aotensor<2>(delta, "delta");
+      params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q");
+      params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k");
+      params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty
+      params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty
+      params.dropout_p = p_dropout;
+      params.philox_seed_ptr = mk_aoscalartensor(philox_seed);
+      params.philox_offset1 = mk_aoscalartensor(philox_offset);
+      params.philox_offset2 = 0;
+      params.causal_type = CausalType::WindowedAttention;
+      params.varlen_type = VarlenType::CompactVarlen;
+      params.window_left = window_left;
+      params.window_right = window_right;
+      err = aotriton::v3::flash::attn_bwd(params,
+                                          aotriton::v3::flash::attn_bwd_params::kVersion,
+                                          stream);
+    } else {
+      using aotriton::v2::flash::attn_bwd_compact_varlen;
+      using sdp::aotriton_adapter::cast_dtype;
+      aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
+      err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"),
+                                    mk_aotensor(k_padded, "k"),
+                                    mk_aotensor(v_padded, "v"),
+                                    mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"),
+                                    mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"),
+                                    max_seqlen_q,
+                                    max_seqlen_k,
+                                    empty_bias,
+                                    softmax_scale,
+                                    mk_aotensor(out_t, "out"),
+                                    mk_aotensor(dout_t, "dout"),
+                                    mk_aotensor(dq_padded, "dq"),
+                                    mk_aotensor(dk_padded, "dk"),
+                                    mk_aotensor(dv_padded, "dv"),
+                                    empty_bias,
+                                    mk_aotensor<2>(softmax_lse_cont, "L"),
+                                    mk_aotensor<2>(delta, "delta"),
+                                    p_dropout,
+                                    mk_aoscalartensor(philox_seed),
+                                    mk_aoscalartensor(philox_offset),
+                                    0,
+                                    is_causal,
+                                    stream);
+    }
   } else {
     // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
     dq.zero_();
```
```diff
@@ -51,8 +51,8 @@ mha_fwd_aot(
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_);
 
@@ -87,8 +87,8 @@ mha_varlen_fwd_aot(
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool return_softmax,
     const std::optional<at::Generator>& gen_);
 
@@ -110,8 +110,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd_aot(
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
    const at::Tensor& philox_offset);
@@ -141,8 +141,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
     const float softmax_scale,
     const bool zero_tensors,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const bool deterministic,
     const at::Tensor& philox_seed,
     const at::Tensor& philox_offset);
```
```diff
@@ -290,14 +290,16 @@ mha_fwd(
     const float p_dropout,
     const float softmax_scale,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
     std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
     return mha_fwd_ck(
         q,
@@ -307,27 +309,13 @@ mha_fwd(
         p_dropout,
         softmax_scale,
         is_causal,
-        window_size_left,
-        window_size_right,
+        non_null_window_left,
+        non_null_window_right,
         return_softmax,
         gen_,
         dummy_attn_bias); // Not used in flash attention
-  } else {
-    return mha_fwd_aot(
-        q,
-        k,
-        v,
-        out_,
-        alibi_slopes_,
-        p_dropout,
-        softmax_scale,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        return_softmax,
-        gen_);
   }
-#else
+#endif
   return mha_fwd_aot(
       q,
       k,
@@ -341,7 +329,6 @@ mha_fwd(
       window_size_right,
       return_softmax,
       gen_);
-#endif
 }
 
 inline std::tuple<
```
```diff
@@ -376,8 +363,8 @@ mha_varlen_fwd(
     const float softmax_scale,
     const bool zero_tensors,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const float softcap,
     const bool return_softmax,
     std::optional<at::Generator> gen_) {
@@ -385,6 +372,8 @@ mha_varlen_fwd(
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
     std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
     return mha_varlen_fwd_ck(
         q,
         k,
@@ -399,34 +388,13 @@ mha_varlen_fwd(
         softmax_scale,
         zero_tensors,
         is_causal,
-        window_size_left,
-        window_size_right,
+        non_null_window_left,
+        non_null_window_right,
         return_softmax,
         gen_,
         dummy_attn_bias); // Not used in flash attention
-  } else {
-    return mha_varlen_fwd_aot(
-        q,
-        k,
-        v,
-        out_,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        seqused_k,
-        block_table_,
-        alibi_slopes_,
-        max_seqlen_q,
-        max_seqlen_k,
-        p_dropout,
-        softmax_scale,
-        zero_tensors,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        return_softmax,
-        gen_);
   }
-#else
+#endif
   return mha_varlen_fwd_aot(
       q,
       k,
@@ -447,7 +415,6 @@ mha_varlen_fwd(
       window_size_right,
       return_softmax,
      gen_);
-#endif
 }
 
 inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
```
```diff
@@ -468,16 +435,18 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
     const float p_dropout, // probability to drop
     const float softmax_scale,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
     const at::Tensor philox_offset) {
-#if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+#if defined(USE_CK_FLASH_ATTENTION)
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
     auto[dQuery,
          dKey,
          dValue,
@@ -498,38 +467,16 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
         p_dropout,
         softmax_scale,
         is_causal,
-        window_size_left,
-        window_size_right,
+        non_null_window_left,
+        non_null_window_right,
         deterministic,
         philox_seed,
         philox_offset);
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
-  } else {
-    return mha_bwd_aot(
-        dout,
-        q,
-        k,
-        v,
-        out,
-        softmax_lse,
-        dq_,
-        dk_,
-        dv_,
-        alibi_slopes_,
-        p_dropout,
-        softmax_scale,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        deterministic,
-        philox_seed,
-        philox_offset);
-  }
 #else
-  if(at::globalContext().getROCmFAPreferredBackend() ==
-     at::ROCmFABackend::Ck) {
     TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
+#endif
   }
   return mha_bwd_aot(
       dout,
@@ -550,7 +497,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
       deterministic,
       philox_seed,
       philox_offset);
-#endif
 }
 
 inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd(
```
```diff
@@ -578,8 +524,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
     const float softmax_scale,
     const bool zero_tensors,
     const bool is_causal,
-    int window_size_left,
-    int window_size_right,
+    std::optional<int64_t> window_size_left,
+    std::optional<int64_t> window_size_right,
     const float softcap,
     const bool deterministic,
     const at::Tensor philox_seed,
@@ -588,6 +534,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
     auto[dQuery,
          dKey,
          dValue,
@@ -613,40 +561,15 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
         softmax_scale,
         zero_tensors,
         is_causal,
-        window_size_left,
-        window_size_right,
+        non_null_window_left,
+        non_null_window_right,
         deterministic,
         philox_seed,
         philox_offset);
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
-  } else {
-    return mha_varlen_bwd_aot(
-        dout,
-        q,
-        k,
-        v,
-        out,
-        softmax_lse,
-        dq_,
-        dk_,
-        dv_,
-        cu_seqlens_q,
-        cu_seqlens_k,
-        alibi_slopes_,
-        max_seqlen_q,
-        max_seqlen_k,
-        p_dropout,
-        softmax_scale,
-        zero_tensors,
-        is_causal,
-        window_size_left,
-        window_size_right,
-        deterministic,
-        philox_seed,
-        philox_offset);
   }
-#else
+#endif
   return mha_varlen_bwd_aot(
       dout,
       q,
@@ -671,7 +594,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
       deterministic,
       philox_seed,
      philox_offset);
-#endif
 }
 
 } // namespace pytorch_flash
```
cmake/External/aotriton.cmake (vendored, 33 lines changed)
```diff
@@ -1,16 +1,3 @@
-macro(get_target_gpus_from_pytorch target_gpus)
-  set(gfx90a_key MI200)
-  set(gfx942_key MI300X)
-  set(gfx1100_key Navi31)
-
-  foreach(X IN LISTS PYTORCH_ROCM_ARCH)
-    set(key ${X})
-    string(APPEND key "_key")
-    string(APPEND target_gpus ${${key}})
-    string(APPEND target_gpus "|")
-  endforeach()
-endmacro()
-
 if(NOT __AOTRITON_INCLUDED)
   set(__AOTRITON_INCLUDED TRUE)
 
@@ -22,22 +9,22 @@ if(NOT __AOTRITON_INCLUDED)
   # Replaces .ci/docker/aotriton_version.txt
   # Note packages information may have versions skipped (due to no ABI breaks)
   # But they must be listed from lower version to higher version
-  set(__AOTRITON_VER "0.9.2b")
+  set(__AOTRITON_VER "0.10b")
   set(__AOTRITON_MANYLINUX_LIST
-      "manylinux_2_28" # rocm6.2
       "manylinux_2_28" # rocm6.3
       "manylinux_2_28" # rocm6.4
+      "manylinux_2_28" # rocm7.0
      )
   set(__AOTRITON_ROCM_LIST
-      "rocm6.2"
       "rocm6.3"
       "rocm6.4"
+      "rocm7.0"
      )
-  set(__AOTRITON_CI_COMMIT "b388d223d8c7213545603e00f6f3148c54d1f525")
+  set(__AOTRITON_CI_COMMIT "6fca155f4deeb8d9529326f7b69f350aeeb93477")
   set(__AOTRITON_SHA256_LIST
-      "08d84f96f4c984179f80f517c0431c7511ee26bb0ce9bd05a827573ddd78cc79" # rocm6.2
-      "9094d59717e7e6eace9126ca100dd0e86510f07fc6c3a349569fc4e2d9056604" # rocm6.3
-      "41190202c2736d5ff75b13a3abc0fb52ebfbb67226cf85dc3de7699c7000db44" # rocm6.4
+      "861cd9f7479eec943933c27cb86920247e5b5dd139bc7c1376c81808abb7d7fe" # rocm6.3
+      "acea7d811a2d3bbe718b6e07fc2a9f739e49eecd60b4b6a36fcb3fe8edf85d78" # rocm6.4
+      "7e29c325d5bd33ba896ddb106f5d4fc7d715274dca7fe937f724fffa82017838" # rocm7.0
      )
   set(__AOTRITON_Z "gz")
 
@@ -50,17 +37,13 @@ if(NOT __AOTRITON_INCLUDED)
     set(__AOTRITON_INSTALL_DIR "$ENV{AOTRITON_INSTALLED_PREFIX}")
     message(STATUS "Using Preinstalled AOTriton at ${__AOTRITON_INSTALL_DIR}")
   elseif(DEFINED ENV{AOTRITON_INSTALL_FROM_SOURCE})
-    set(target_gpus "")
-    get_target_gpus_from_pytorch(target_gpus)
     ExternalProject_Add(aotriton_external
       GIT_REPOSITORY https://github.com/ROCm/aotriton.git
       GIT_TAG ${__AOTRITON_CI_COMMIT}
       PREFIX ${__AOTRITON_EXTERN_PREFIX}
       INSTALL_DIR ${__AOTRITON_INSTALL_DIR}
-      LIST_SEPARATOR |
       CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${__AOTRITON_INSTALL_DIR}
-      -DTARGET_GPUS:STRING=${target_gpus}
-      -DAOTRITON_COMPRESS_KERNEL=ON
+      -DAOTRITON_TARGET_ARCH:STRING=${PYTORCH_ROCM_ARCH}
       -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
       -DAOTRITON_NO_PYTHON=ON
       -DAOTRITON_NO_SHARED=OFF
```
```diff
@@ -3294,7 +3294,7 @@ class TestSDPACudaOnly(NNTestCase):
                     fudge_factors['grad_key'] = 70.0
                     if seq_len_k >= 2048:
                         fudge_factors['grad_key'] = 160.0
-                        fudge_factors['grad_query'] = 650.0
+                        fudge_factors['grad_query'] = 670.0
                     if dtype == torch.float32:
                         fudge_factors['grad_key'] = 90.0
 
@@ -3415,7 +3415,7 @@ class TestSDPACudaOnly(NNTestCase):
                     fudge_factors['grad_key'] = 70.0
                     if seq_len_k >= 2048:
                         fudge_factors['grad_key'] = 160.0
-                        fudge_factors['grad_query'] = 650.0
+                        fudge_factors['grad_query'] = 670.0  # gfx90a
                     if dtype == torch.float32:
                         fudge_factors['grad_key'] = 90.0
```
```diff
@@ -57,9 +57,9 @@ def CDNA2OrLater():
 
 def evaluate_platform_supports_flash_attention():
     if TEST_WITH_ROCM:
-        arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
         if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
-            arch_list += ["gfx1201", "gfx950"]
+            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return not IS_WINDOWS and SM80OrLater
@@ -67,9 +67,9 @@ def evaluate_platform_supports_flash_attention():
 
 def evaluate_platform_supports_efficient_attention():
     if TEST_WITH_ROCM:
-        arch_list = ["gfx90a", "gfx942", "gfx1100"]
+        arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"]
         if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0":
-            arch_list += ["gfx1201", "gfx950"]
+            arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"]
         return evaluate_gfx_arch_within(arch_list)
     if TEST_CUDA:
         return True
```