Valentine233
2025-08-25 01:58:23 +00:00
parent b0e48e2bb7
commit fb1a9f6d8e

@@ -783,8 +783,9 @@ extern "C"
const scalar_t* v_data = value;
scalar_t* out_data = output;
auto actual_kvSize = kvSize / batchSize;
auto num_partitions =
(kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
(actual_kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
// Allocate temp buf (accumulate type)
int64_t _accum_buff_size =
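A minimal standalone sketch of the partition-count change above, with made-up sizes (PARTITION_SIZE, batchSize, kvSize are illustrative, not taken from a real run): once the KV cache is flattened across the batch, the ceil-division has to run over the per-batch length actual_kvSize rather than the flattened kvSize.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sizes for illustration only.
  const int64_t PARTITION_SIZE = 128;
  const int64_t batchSize = 4;
  const int64_t kvSize = 4000;                        // flattened across the batch
  const int64_t actual_kvSize = kvSize / batchSize;   // 1000 tokens per batch entry
  // Ceil-division over the per-batch length: 8 partitions per (batch, head).
  // Dividing the flattened kvSize instead would give 32.
  const int64_t num_partitions =
      (actual_kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
  std::cout << "num_partitions = " << num_partitions << "\n";  // prints 8
  return 0;
}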
@@ -808,6 +809,7 @@ extern "C"
auto tmp_out_strideH = num_partitions * headSize_v;
auto tmp_out_strideS = headSize_v;
// TODO: For GQA, the parallelism dimension could be switched from q_num_head to kv_num_head
// Attention loop
at::parallel_for(0, batchSize * num_head * num_partitions, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, partition_id = 0;
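For reference, one way to decompose the flat work-item index used by the parallel_for above into (batch i, head j, partition_id); the kernel's own index bookkeeping is not shown in this hunk, so the sketch below is only an equivalent illustration.

#include <cstdint>
#include <cstdio>

// Sketch: decompose a flat work-item index into (batch i, head j, partition_id),
// matching a loop over batchSize * num_head * num_partitions work items.
int main() {
  const int64_t batchSize = 2, num_head = 3, num_partitions = 4;
  for (int64_t z = 0; z < batchSize * num_head * num_partitions; ++z) {
    const int64_t partition_id = z % num_partitions;
    const int64_t j = (z / num_partitions) % num_head;
    const int64_t i = z / (num_partitions * num_head);
    std::printf("z=%lld -> i=%lld j=%lld partition=%lld\n",
                (long long)z, (long long)i, (long long)j, (long long)partition_id);
  }
  return 0;
}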
@@ -820,7 +822,8 @@ extern "C"
: nullptr;
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
auto partition_start = partition_id * PARTITION_SIZE;
auto n_offset = i * actual_kvSize;
auto partition_start = n_offset + partition_id * PARTITION_SIZE;
auto partition_end =
std::min(partition_start + PARTITION_SIZE, kvSize);
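A self-contained sketch of the offset change above (all sizes are assumed): with n_offset added, each (batch, partition) pair addresses a window inside that batch's slice of the flattened KV buffer instead of always starting at offset 0.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Sketch: map (batch i, partition_id) to a [start, end) window in the
// flattened KV buffer, mirroring the n_offset-based indexing above.
// actual_kvSize is chosen as a multiple of PARTITION_SIZE so the clamp stays trivial.
int main() {
  const int64_t PARTITION_SIZE = 128;
  const int64_t batchSize = 2;
  const int64_t actual_kvSize = 256;              // per-batch KV length (assumed)
  const int64_t kvSize = batchSize * actual_kvSize;
  const int64_t num_partitions =
      (actual_kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
  for (int64_t i = 0; i < batchSize; ++i) {
    const int64_t n_offset = i * actual_kvSize;   // start of batch i's KV slice
    for (int64_t p = 0; p < num_partitions; ++p) {
      const int64_t partition_start = n_offset + p * PARTITION_SIZE;
      const int64_t partition_end =
          std::min(partition_start + PARTITION_SIZE, kvSize);
      std::printf("batch %lld partition %lld -> [%lld, %lld)\n",
                  (long long)i, (long long)p,
                  (long long)partition_start, (long long)partition_end);
    }
  }
  return 0;
}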
@@ -930,9 +933,9 @@ extern "C"
{{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
}
}
{%- endif %}
token_num += cur_kvSplitSize;
}
{%- endif %}
// 2) calculate the max and exp_sum for this partition
@@ -1044,13 +1047,11 @@ extern "C"
auto global_exp_sum = 0.0;
// Calculate the global max and exp_sum for this head
for (auto partition_id = 0; partition_id < num_partitions;
partition_id++) {
auto max_logit = max_logits_ptr
[i * max_logits_strideN +
j * max_logits_strideH + partition_id];
global_max = std::max(global_max, max_logit);
}
global_max = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
max_logits_ptr + i * max_logits_strideN
+ j * max_logits_strideH,
num_partitions);
// Update the partition 0 result with the global max
auto partition0_out_start =
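The at::vec::reduce_all call above replaces the earlier scalar loop; a plain C++ sketch of the value it computes over the per-partition max logits (no ATen dependency, numbers are made up):

#include <algorithm>
#include <cstdio>
#include <limits>

// Sketch: the scalar equivalent of reducing num_partitions max logits to a
// single global max (the vectorized version above does this with at::vec).
int main() {
  const float max_logits[] = {1.5f, -0.25f, 3.0f, 0.75f};  // per-partition maxima
  const int num_partitions = 4;
  float global_max = -std::numeric_limits<float>::infinity();
  for (int p = 0; p < num_partitions; ++p) {
    global_max = std::max(global_max, max_logits[p]);
  }
  std::printf("global_max = %f\n", global_max);  // 3.000000
  return 0;
}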
@@ -1058,7 +1059,7 @@ extern "C"
auto max_logit0 = max_logits_ptr
[i * max_logits_strideN + j * max_logits_strideH];
float exp_val = std::exp(max_logit0 - global_max);
global_exp_sum +=
global_exp_sum =
exp_sum_ptr[i * exp_sum_strideN + j * exp_sum_strideH] *
exp_val;
at::vec::map<accum_t>(
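The exp_val scaling above follows the usual log-sum-exp merge: a partition's exp_sum was accumulated against its local max, so multiplying by exp(local_max - global_max) rebases it onto the global max before the partitions are combined. A self-contained numerical sketch (values are made up):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Sketch: merge per-partition (local_max, exp_sum) pairs into a global
// softmax normalizer. exp_sum[p] was accumulated as sum(exp(x - local_max[p])),
// so multiplying by exp(local_max[p] - global_max) rebases it onto global_max.
int main() {
  const double local_max[] = {1.0, 3.0};
  const double exp_sum[]   = {2.5, 1.8};   // per-partition sums (made up)
  const double global_max  = std::max(local_max[0], local_max[1]);
  double global_exp_sum = 0.0;
  for (int p = 0; p < 2; ++p) {
    global_exp_sum += exp_sum[p] * std::exp(local_max[p] - global_max);
  }
  std::printf("global_exp_sum = %f\n", global_exp_sum);
  return 0;
}

The hunk initializes global_exp_sum from partition 0's rebased sum; the remaining partitions are presumably folded in the same way further down (not shown here).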
@@ -1400,6 +1401,29 @@ class CppFlexAttentionTemplate(CppTemplate):
def apply_score_mod(self, score, b, h, q_idx, kv_idx):
return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
def choose_flex_template(
self,
query: ir.Buffer,
num_threads,
):
# Choose between the FLEX_ATTENTION and FLEX_DECODING templates
FLEX_TEMPLATE = FLEX_ATTENTION_TEMPLATE
q_batch_size = query.data.data.layout.size[0]
q_num_heads = query.data.data.layout.size[1]
q_seq_len = query.data.data.layout.size[2]
if all(
sympy.sympify(val).is_number
for val in [q_batch_size, q_num_heads, q_seq_len, num_threads]
):
# if static shape
if (
self.partition_size % self.kv_block_size == 0
and q_seq_len == 1
and num_threads > q_batch_size * q_num_heads
):
FLEX_TEMPLATE = FLEX_DECODING_TEMPLATE
return FLEX_TEMPLATE
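Outside the sympy static-shape check, the new heuristic reduces to a simple predicate; the C++ sketch below restates it with illustrative values (the function name and arguments are hypothetical, not part of the change):

#include <cstdint>
#include <cstdio>

// Sketch of the decoding-vs-attention heuristic: prefer the decoding template
// only for single-token queries whose KV blocks tile the partition evenly and
// where there are more threads than (batch * heads) work items.
static bool use_flex_decoding(int64_t partition_size, int64_t kv_block_size,
                              int64_t q_seq_len, int64_t q_batch_size,
                              int64_t q_num_heads, int64_t num_threads) {
  return partition_size % kv_block_size == 0 && q_seq_len == 1 &&
         num_threads > q_batch_size * q_num_heads;
}

int main() {
  std::printf("%d\n", use_flex_decoding(128, 64, 1, 2, 8, 32));  // 1: decoding
  std::printf("%d\n", use_flex_decoding(128, 64, 7, 2, 8, 32));  // 0: attention
  return 0;
}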
def render( # type: ignore[override,return]
self,
kernel,
@@ -1464,19 +1488,10 @@ class CppFlexAttentionTemplate(CppTemplate):
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
)
if (
query.data.data.layout.size[2] == 1
and self.partition_size % self.kv_block_size == 0
):
# use flash decoding when qSize == 1
return self._template_from_string(FLEX_DECODING_TEMPLATE).render(
**options
)
else:
# use flash attention when qSize > 1
return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(
**options
)
FLEX_TEMPLATE = self.choose_flex_template(query, num_threads)
return self._template_from_string(FLEX_TEMPLATE).render(
**options
)
def codegen_softmax_fusion(self, kernel_name: str):
# TODO: use inductor IR to rewrite those fusions