Compare commits

...

32 Commits

SHA1 Message Date
d79d17bc2c fix format 2025-09-24 05:33:35 +00:00
53a367ed24 fix format 2025-09-24 05:22:43 +00:00
d53d01e98d enable non-avx2 machines 2025-09-24 01:46:15 +00:00
91db5c8ad7 add ut for 1 partition 2025-09-24 01:43:38 +00:00
3d46e832cd fix ut 2025-09-24 01:43:38 +00:00
b81c5aee57 fix ut 2025-09-24 01:43:38 +00:00
b56ad1d8e0 fix ut 2025-09-24 01:43:38 +00:00
7de906c1e0 update and rebase 2025-09-24 01:43:38 +00:00
17e3631b3f fix format 2025-09-24 01:43:38 +00:00
e7ba15c391 enable more dtypes in ut 2025-09-24 01:43:38 +00:00
dba453f5f8 update flash decoding codes 2025-09-24 01:43:38 +00:00
198bff0681 add skip cpu 2025-09-24 01:43:38 +00:00
03503e975b remove ut change 2025-09-24 01:43:38 +00:00
7fe5e2d01d refine code 2025-09-24 01:43:38 +00:00
5c415f7ef6 refine code 2025-09-24 01:43:38 +00:00
493f6aac2c fix typo 2025-09-24 01:43:38 +00:00
8b868da737 add comments 2025-09-24 01:43:38 +00:00
208b9db2ed update 2025-09-24 01:43:38 +00:00
064fbcbcda fix ut issue 2025-09-24 01:43:38 +00:00
64dae88032 fix format 2025-09-24 01:43:38 +00:00
73c9f888aa fix format 2025-09-24 01:43:38 +00:00
ff5fd980cc fix format 2025-09-24 01:43:38 +00:00
9accbd0b70 add ut for partition size 2025-09-24 01:43:38 +00:00
1c81676302 update 2025-09-24 01:43:38 +00:00
fb1a9f6d8e update 2025-09-24 01:43:38 +00:00
b0e48e2bb7 fix format 2025-09-24 01:43:38 +00:00
0dfe0ba5e9 fix format 2025-09-24 01:43:38 +00:00
679d23190e fix format 2025-09-24 01:43:38 +00:00
46fa2893b4 fix format 2025-09-24 01:43:38 +00:00
dd8736f62c fix format 2025-09-24 01:43:38 +00:00
c95f75f95c update 2025-09-24 01:43:37 +00:00
056d4ca2b6 support flash decoding for cpu 2025-09-24 01:43:37 +00:00
4 changed files with 478 additions and 32 deletions

View File

@@ -27,6 +27,7 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf3
from torch.testing._internal.common_device_type import (
flex_attention_supported_platform as supported_platform,
instantiate_device_type_tests,
skipCUDAIf,
skipXPUIf,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
@@ -85,7 +86,7 @@ else:
LONG_COMPILATION_ON_CPU = True
test_dtypes = (
[torch.float32, torch.bfloat16]
[torch.float32, torch.bfloat16, torch.float16]
if torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float32]
@@ -577,6 +578,7 @@ class TestFlexDecoding(InductorTestCase):
v: Tensor,
dtype: torch.dtype = torch.float16,
block_mask: Optional[BlockMask] = None,
kernel_options: Optional[dict] = None,
device="cuda",
):
Q_B, Q_H, KV_H = q.shape[0], q.shape[1], k.shape[1]
@@ -607,6 +609,7 @@ class TestFlexDecoding(InductorTestCase):
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
else:
compiled_lse = None
@@ -618,6 +621,7 @@ class TestFlexDecoding(InductorTestCase):
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
return compiled_out, compiled_lse
@@ -634,6 +638,7 @@ class TestFlexDecoding(InductorTestCase):
KV_S: int = S,
V_D: int = D,
block_mask: Optional[BlockMask] = None,
kernel_options: Optional[dict] = None,
device="cuda",
):
assert Q_H % KV_H == 0
@@ -670,7 +675,14 @@ class TestFlexDecoding(InductorTestCase):
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
compiled_out, compiled_lse = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask, device
score_mod,
q,
k,
v,
dtype,
block_mask,
device=device,
kernel_options=kernel_options,
)
self._check_out(
@@ -737,7 +749,7 @@ class TestFlexDecoding(InductorTestCase):
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
compiled_out, _ = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask, device
score_mod, q, k, v, dtype, block_mask, device=device
)
self._check_out(
@@ -1570,6 +1582,23 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
):
flex_attention(query, key, value, _identity)
@supported_platform
@skipCUDAIf(True, "Not supported on CUDA")
@skipXPUIf(True, "Not supported on XPU")
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("partition_size", [64, 128, 256, 1024])
def test_flash_decoding_partition_size(self, device, dtype, partition_size):
def score_mod(score, b, h, m, n):
return score * 2
self.run_test_with_paged_attention(
score_mod,
dtype,
KV_S=64,
device=device,
kernel_options={"PARTITION_SIZE": partition_size},
)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune(self, device):
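As a quick reference, here is a minimal sketch (not part of the diff) of how the PARTITION_SIZE kernel option exercised by the new test_flash_decoding_partition_size test could be passed through the public flex_attention API on CPU; the shapes and the toy score_mod are illustrative assumptions, and the flash-decoding path is only picked when the heuristics in choose_flex_template below are satisfied.

# Hedged usage sketch: pass PARTITION_SIZE via kernel_options (illustrative shapes).
import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    return score * 2  # same toy modification as the new unit test

B, H, KV_S, D = 1, 16, 2048, 64
q = torch.randn(B, H, 1, D)      # decoding: a single query token
k = torch.randn(B, H, KV_S, D)
v = torch.randn(B, H, KV_S, D)

compiled = torch.compile(flex_attention)
out = compiled(q, k, v, score_mod=score_mod,
               kernel_options={"PARTITION_SIZE": 256})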

View File

@@ -208,7 +208,7 @@ ALLOCATE_BUFFER = r"""
{{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr;
"""
FLEX_ATTENTION_TEMPLATE = r"""
INIT_PARAMS = r"""
{{template.header().getvalue()}}
#include <ATen/native/cpu/utils.h>
#include <ATen/native/CPUBlas.h>
@@ -225,16 +225,18 @@ extern "C"
{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}}
{
{{ kernel.maybe_codegen_profile() }}
int64_t qBlockSize = {{qBlockSize}};
int64_t kvBlockSize = {{kvBlockSize}};
int64_t num_thread = {{num_thread}};
// dtypes of kernel and internal buffers
// dtypes
using scalar_t = {{kernel.dtype(query)}};
constexpr bool is_reduced_type = c10::is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
using Vec = at::vec::Vectorized<accum_t>;
accum_t scaling_factor = {{scale}};
// sizes
int64_t qBlockSize = {{qBlockSize}};
int64_t kvBlockSize = {{kvBlockSize}};
int64_t num_thread = {{num_thread}};
int64_t batchSize = {{kernel.size(query, 0)}};
int64_t qSize = {{kernel.size(query, 1)}};
int64_t num_head = {{kernel.size(query, 2)}};
@@ -255,6 +257,18 @@ extern "C"
int64_t gqa_shards_kvi = num_head / num_head_kvi;
int64_t bs_shards_kvi = batchSize / batchSize_kvi;
int64_t kvSize = {{kernel.size(key, 1)}};
int64_t qSplitSize = qBlockSize;
int64_t kvSplitSize = kvBlockSize;
qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
// Strides
int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}};
int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}};
int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}};
@@ -276,7 +290,6 @@ extern "C"
auto kv_num_blocks_data = kv_num_blocks;
auto kv_indices_data = kv_indices;
// Strides
int64_t qStrideB = {{kernel.stride(query, 0)}};
int64_t qStrideM = {{kernel.stride(query, 1)}};
int64_t qStrideH = {{kernel.stride(query, 2)}};
@@ -290,18 +303,15 @@ extern "C"
int64_t oStrideM = {{kernel.stride(output, 2)}};
int64_t oStrideH = {{kernel.stride(output, 1)}};
int64_t kvSize = {{kernel.size(key, 1)}};
// Inputs/outputs buffers
const scalar_t* q_data = query;
const scalar_t* k_data = key;
const scalar_t* v_data = value;
scalar_t* out_data = output;
int64_t qSplitSize = qBlockSize;
int64_t kvSplitSize = kvBlockSize;
qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
"""
FLEX_ATTENTION_TEMPLATE = r"""
bool need_pack = false;
// Whether pack is needed for BFloat16/Half
if (is_reduced_type) {
@@ -335,12 +345,6 @@ extern "C"
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * headSize_v;
// Inputs/outputs buffers
const scalar_t* q_data = query;
const scalar_t* k_data = key;
const scalar_t* v_data = value;
scalar_t* out_data = output;
// Buffers to store accum results, padding query and transpose/packing key/value
{{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}}
{{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}}
@@ -683,8 +687,389 @@ extern "C"
}
"""
FLEX_DECODING_TEMPLATE = r"""
int64_t PARTITION_SIZE = {{partition_size}};
// Check whether the score / mask mod depends on batch_size / num_head
// Take a fast path if it is independent
bool bs_head_independent_mod = true;
int64_t first_num_kvblocks = kv_num_blocks[0];
int64_t first_full_num_kvblocks = full_kv_num_blocks[0];
for (const auto& b : c10::irange(batchSize_kvi)) {
for (const auto& h : c10::irange(num_head_kvi)) {
if (*(kv_num_blocks + b * num_kviStrideB + h * num_kviStrideH) != first_num_kvblocks
|| *(full_kv_num_blocks + b * full_num_kviStrideB + h * full_num_kviStrideH) != first_full_num_kvblocks) {
bs_head_independent_mod = false;
break;
}
}
}
int64_t num_kvblocks_per_seq = kv_num_blocks[0] + full_kv_num_blocks[0];
int64_t num_kvblocks_per_partition = PARTITION_SIZE / kvBlockSize;
int64_t num_partitions = (num_kvblocks_per_seq + num_kvblocks_per_partition - 1) / num_kvblocks_per_partition;
if (!bs_head_independent_mod) {
num_partitions =
(kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
}
// Allocate temp buf (accumulate type)
int64_t _accum_buff_size =
/* max_logits_ptr */ batchSize * num_head * num_partitions +
/* exp_sum_ptr */ batchSize * num_head * num_partitions +
/* tmp_out_ptr */ batchSize * num_head * num_partitions * headSize_v +
/* logits_ptrs */ num_thread * PARTITION_SIZE;
{{template.codegen_allocate_buffer("buf_data", "accum_t", "_accum_buff_size")}}
accum_t* max_logits_ptr = buf_data;
accum_t* exp_sum_ptr = max_logits_ptr + batchSize * num_head * num_partitions;
accum_t* tmp_out_ptr = exp_sum_ptr + batchSize * num_head * num_partitions;
accum_t* logits_ptrs = tmp_out_ptr + batchSize * num_head * num_partitions * headSize_v;
{{template.codegen_allocate_buffer("logits_reduced_ptrs", "scalar_t", "num_thread * PARTITION_SIZE")}}
auto max_logits_strideN = num_head * num_partitions;
auto max_logits_strideH = num_partitions;
auto exp_sum_strideN = num_head * num_partitions;
auto exp_sum_strideH = num_partitions;
auto tmp_out_strideN = num_head * num_partitions * headSize_v;
auto tmp_out_strideH = num_partitions * headSize_v;
auto tmp_out_strideS = headSize_v;
// Attention loop
at::parallel_for(0, batchSize * num_head * num_partitions, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, partition_id = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head, partition_id, num_partitions);
int ompIdx = at::get_thread_num();
accum_t* logits = logits_ptrs + ompIdx * PARTITION_SIZE;
scalar_t* logits_reduced =
is_reduced_type
? logits_reduced_ptrs + ompIdx * PARTITION_SIZE
: nullptr;
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
auto kvblock_offset = num_kvblocks_per_partition * partition_id;
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB +
j_kvi * num_kviStrideH;
int64_t kv_indice_num = *kv_logical_num_data;
std::vector<int64_t> kv_indice_list(kv_indice_num);
for(int64_t kv_i = 0; kv_i < kv_indice_num; kv_i++){
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_i;
kv_indice_list[kv_i] = *kv_logical_data;
}
{%- if has_full_kv_block %}
auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB +
j_kvi * num_kviStrideH;
int64_t full_kv_indice_num = *full_kv_logical_num_data;
std::vector<int64_t> full_kv_indice_list(full_kv_indice_num);
for(int64_t kv_i = 0; kv_i < full_kv_indice_num; kv_i++){
auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB +
j_kvi * full_kviStrideH + kv_i;
full_kv_indice_list[kv_i] = *full_kv_logical_data;
}
{%- endif %}
int64_t cur_qSplitSize = 1;
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
accum_t* tmp_out = tmp_out_ptr + i * tmp_out_strideN +
j * tmp_out_strideH + partition_id * tmp_out_strideS;
// Initialize logits
{{kernel.kernel_name}}_fill_stub(logits,
static_cast<accum_t>(0), PARTITION_SIZE);
if (is_reduced_type) {
{{kernel.kernel_name}}_fill_stub(logits_reduced,
static_cast<scalar_t>(0), PARTITION_SIZE);
}
// 1) calculate the matmul(query, key) for this partition
int64_t token_num = 0;
{%- if has_full_kv_block %}
int64_t n_idx_start = kvblock_offset;
int64_t n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num + full_kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num + full_kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
{%- else %}
int64_t n_idx_start = kvblock_offset;
int64_t n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = kv_indice_list[n_idx]*kvSplitSize;
{%- endif %}
if (!bs_head_independent_mod
&& (n < partition_id * PARTITION_SIZE
|| n >= std::min(partition_id * PARTITION_SIZE + PARTITION_SIZE, kvSize))) {
continue;
}
auto cur_n = n/kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
{{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<false>(
q_data + i * qStrideB + j * qStrideH,
k_addr,
logits + token_num,
cur_qSplitSize,
cur_kvSplitSize,
headSize,
qStrideM,
kStrideN,
cur_kvSplitSize);
{{kernel.kernel_name}}_mul_scale_kernel<accum_t>(logits + token_num, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
{%- if score_mod and mask_mod %}
// TODO: reduce the number of q_idx and kv_idx initializations
std::vector<int64_t> q_idx(cur_qSplitSize);
for (int64_t i = 0; i < cur_qSplitSize; ++i) {
q_idx[i] = i;
}
std::vector<int64_t> kv_idx(cur_kvSplitSize);
for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
kv_idx[i] = n + i;
}
std::vector<int64_t> b_idx = {i};
std::vector<int64_t> h_idx = {j};
accum_t* in_ptr0 = logits + token_num;
const auto in_ptr1 = b_idx.data();
const auto in_ptr2 = h_idx.data();
const auto in_ptr3 = q_idx.data();
const auto in_ptr4 = kv_idx.data();
// apply score mod function
{
{{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
{{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
}
if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){
// Apply block mask, fill unused with -inf
{
{{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
{{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
}
}
{%- endif %}
token_num += cur_kvSplitSize;
}
// 2) calculate the max and exp_sum for this partition
auto partition_max = -std::numeric_limits<float>::infinity();
{{kernel.kernel_name}}_mul_reduce_max_fusion_kernel(
logits,
static_cast<accum_t>(1),
token_num,
logits,
partition_max);
if (partition_max == -std::numeric_limits<float>::infinity()) {
partition_max = 0;
}
max_logits_ptr[i * max_logits_strideN +
j * max_logits_strideH + partition_id] =
partition_max;
{{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel(
logits,
token_num,
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced),
partition_max);
exp_sum_ptr[i * exp_sum_strideN +
j * exp_sum_strideH + partition_id] = partition_max;
// 3) calculate the matmul(exp(logits - partition_max), value) for this
// partition; the global exp_sum is divided out in the final result.
token_num = 0;
bool skipped_partition = true;
{%- if has_full_kv_block %}
n_idx_start = kvblock_offset;
n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num + full_kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num + full_kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
{%- else %}
n_idx_start = kvblock_offset;
n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = kv_indice_list[n_idx]*kvSplitSize;
{%- endif %}
if (!bs_head_independent_mod
&& (n < partition_id * PARTITION_SIZE
|| n >= std::min(partition_id * PARTITION_SIZE + PARTITION_SIZE, kvSize))) {
continue;
}
skipped_partition = false;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
// For Half, the fallback brgemm is slower than the micro gemm, so use the micro gemm instead
if constexpr (!std::is_same_v<scalar_t, at::Half>) {
at::native::cpublas::brgemm(
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v,
token_num > 0,
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
false);
} else {
if (token_num > 0) {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v);
} else {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v);
}
}
token_num += cur_kvSplitSize;
}
if (skipped_partition) {
{{kernel.kernel_name}}_fill_stub(tmp_out,
static_cast<accum_t>(0), headSize_v);
}
// Move to the next (batch, head, partition)
at::native::data_index_step(i, batchSize, j, num_head, partition_id, num_partitions);
}
if constexpr (!std::is_same_v<scalar_t, at::Half>) {
at::native::cpublas::brgemm_release();
}
});
// Calculate the final output
at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
auto global_max = -std::numeric_limits<float>::infinity();
auto global_exp_sum = 0.0;
// Calculate the global max and exp_sum for this head
global_max = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
max_logits_ptr + i * max_logits_strideN
+ j * max_logits_strideH,
num_partitions);
// Update the partition 0 result with the global max
auto partition0_out_start =
tmp_out_ptr + i * tmp_out_strideN + j * tmp_out_strideH;
auto max_logit0 = max_logits_ptr
[i * max_logits_strideN + j * max_logits_strideH];
float exp_val = std::exp(max_logit0 - global_max);
global_exp_sum =
exp_sum_ptr[i * exp_sum_strideN + j * exp_sum_strideH] *
exp_val;
at::vec::map<accum_t>(
[exp_val](Vec x) { return x * Vec(exp_val); },
partition0_out_start,
partition0_out_start,
headSize_v);
// Accumulate the partition 1 to partition n result into partition 0
if (num_partitions > 1) {
for (auto partition_id = 1; partition_id < num_partitions;
partition_id++) {
auto tmp_out_start = partition0_out_start + partition_id * tmp_out_strideS;
auto max_logit = max_logits_ptr
[i * max_logits_strideN + j * max_logits_strideH +
partition_id];
auto exp_sum = exp_sum_ptr
[i * exp_sum_strideN + j * exp_sum_strideH +
partition_id];
exp_val = std::exp(max_logit - global_max);
global_exp_sum += exp_sum * exp_val;
at::vec::map2<accum_t>(
[exp_val](Vec a, Vec b) { return a + Vec(exp_val) * b; },
partition0_out_start,
partition0_out_start,
tmp_out_start,
headSize_v);
}
}
// Rescale the partition 0 result with global exp_sum
// The sum for fully masked out rows is 0; set it to 1 to avoid
// NaNs in the output, so that fully masked out rows become 0 instead
global_exp_sum = global_exp_sum == 0 ? 1 : global_exp_sum;
float sum_reciprocal = 1.0 / global_exp_sum;
// copy the partition 0 result into output
at::vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH,
partition0_out_start,
headSize_v);
// Move to the next (batch, head)
at::native::data_index_step(i, batchSize, j, num_head);
}
});
}
"""
class CppFlexAttentionTemplate(CppTemplate):
"""
CPP template for FlexAttention.
This class supports generation of C++ code for broad attention variants,
applying Flash Attention and Flash Decoding. It enables template-based
code synthesis according to user-defined score and mask modifications
and configuration settings.
"""
def __init__(
self,
input_nodes,
@@ -694,6 +1079,7 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod,
kv_block_size,
q_block_size,
partition_size,
has_other_buffer,
no_full_kv_block,
fake_buffers,
@@ -725,6 +1111,7 @@ class CppFlexAttentionTemplate(CppTemplate):
self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None
self.kv_block_size = kv_block_size
self.q_block_size = q_block_size
self.partition_size = partition_size
self.has_other_buffer = has_other_buffer
self.no_full_kv_block = no_full_kv_block
self.other_buffer_input_offset = 2
@@ -921,6 +1308,7 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod,
kv_block_size,
q_block_size,
partition_size,
has_other_buffer,
no_full_kv_block,
fake_buffers,
@@ -946,6 +1334,7 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod=mask_mod,
kv_block_size=kv_block_size,
q_block_size=q_block_size,
partition_size=partition_size,
has_other_buffer=has_other_buffer,
no_full_kv_block=no_full_kv_block,
fake_buffers=fake_buffers,
@@ -960,6 +1349,32 @@ class CppFlexAttentionTemplate(CppTemplate):
def apply_score_mod(self, score, b, h, q_idx, kv_idx):
return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
def choose_flex_template(
self,
query: ir.Buffer,
key: ir.Buffer,
num_threads,
):
# choose from FLEX_ATTENTION or FLEX_DECODING
FLEX_TEMPLATE = FLEX_ATTENTION_TEMPLATE
q_batch_size, q_num_heads, q_seq_len, _ = query.data.data.layout.size # type: ignore[attr-defined]
k_seq_len = key.data.data.layout.size[2] # type: ignore[attr-defined]
if all(
sympy.sympify(val).is_number
for val in [q_batch_size, q_num_heads, q_seq_len, k_seq_len, num_threads]
):
# if static shape, FLEX_DECODING will be chosen when all of the following hold:
# 1) partition size is a multiple of kv block size, so each partition holds whole blocks
# 2) decoding scenario: q seq length is 1
# 3) the actual k seq length (k_seq_len / q_batch_size) is large enough
if (
self.partition_size % self.kv_block_size == 0
and q_seq_len == 1
and k_seq_len / q_batch_size >= max(self.partition_size * 2, 512)
):
FLEX_TEMPLATE = FLEX_DECODING_TEMPLATE
return FLEX_TEMPLATE
def render( # type: ignore[override,return]
self,
kernel,
@@ -1017,13 +1432,17 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_buf_name=self.mask_buf_name,
score_buf_idx=self.score_buf_idx,
mask_buf_idx=self.mask_buf_idx,
partition_size=self.partition_size,
)
with contextlib.ExitStack() as stack:
for buf in self.fake_buffers:
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
)
return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
FLEX_TEMPLATE = self.choose_flex_template(query, key, num_threads)
return self._template_from_string(INIT_PARAMS + FLEX_TEMPLATE).render(
**options
)
def codegen_softmax_fusion(self, kernel_name: str):
# TODO: use inductor IR to rewrite those fusions

View File

@@ -289,6 +289,8 @@ def lower_cpu(
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
# In flash decoding, the partition size used to parallelize over the KV sequence length dim
PARTITION_SIZE = kernel_options.get("PARTITION_SIZE", 128)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), (
@@ -308,6 +310,7 @@ def lower_cpu(
mask_mod=None if skip_mask_score else mask_graph_buffer,
kv_block_size=SPARSE_KV_BLOCK_SIZE,
q_block_size=SPARSE_Q_BLOCK_SIZE,
partition_size=PARTITION_SIZE,
has_other_buffer=has_other_buffer,
no_full_kv_block=no_full_kv_block,
fake_buffers=fake_buffers,
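For orientation, a hedged restatement (hypothetical helper, not the PR's code) of how the PARTITION_SIZE option defaulted to 128 here interacts with the choose_flex_template heuristics shown above when deciding between the flash-attention and flash-decoding templates.

def uses_flex_decoding(q_batch_size: int, q_seq_len: int, k_seq_len: int,
                       kv_block_size: int, partition_size: int = 128) -> bool:
    # Mirrors the static-shape conditions in choose_flex_template.
    return (
        partition_size % kv_block_size == 0                            # whole KV blocks per partition
        and q_seq_len == 1                                              # decoding: single query token
        and k_seq_len / q_batch_size >= max(partition_size * 2, 512)   # KV length large enough
    )

# Example (illustrative sizes): a single-token query over 2048 KV tokens.
assert uses_flex_decoding(q_batch_size=1, q_seq_len=1, k_seq_len=2048, kv_block_size=128)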

View File

@@ -1970,13 +1970,8 @@ def get_all_device_types() -> list[str]:
return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
# skip since currently flex attention requires at least `avx2` support on CPU.
IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = (
not torch.xpu.is_available()
and not torch.cuda.is_available()
and not IS_MACOS
and torch.cpu._is_avx2_supported()
and os.getenv("ATEN_CPU_CAPABILITY") != "default"
not torch.xpu.is_available() and not torch.cuda.is_available() and not IS_MACOS
)
IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED = (
torch.xpu.is_available() and torch.utils._triton.has_triton()