From fb1a9f6d8e2e998acf93709eade359044712cf46 Mon Sep 17 00:00:00 2001
From: Valentine233
Date: Mon, 25 Aug 2025 01:58:23 +0000
Subject: [PATCH] cpp flex attention: fix flash decoding kv partitioning and
 template selection

* Derive the per-batch KV length (actual_kvSize = kvSize / batchSize) and use
  it both to size the partitions and to offset each partition into the packed
  KV buffer.
* Move the {%- endif %} of the mask-mod block above
  token_num += cur_kvSplitSize so the KV token counter is emitted
  unconditionally.
* Reduce the per-partition max logits with at::vec::reduce_all instead of a
  scalar loop, and assign (rather than accumulate) the partition-0
  contribution to global_exp_sum.
* Factor the FLEX_ATTENTION / FLEX_DECODING choice into choose_flex_template(),
  which selects the decoding template only for static shapes with
  q_seq_len == 1, partition_size divisible by kv_block_size, and more threads
  than batch_size * num_heads.
---
 .../codegen/cpp_flex_attention_template.py | 63 ++++++++++++-------
 1 file changed, 39 insertions(+), 24 deletions(-)

diff --git a/torch/_inductor/codegen/cpp_flex_attention_template.py b/torch/_inductor/codegen/cpp_flex_attention_template.py
index e87c8f93cbb2..15bd3001c8d1 100644
--- a/torch/_inductor/codegen/cpp_flex_attention_template.py
+++ b/torch/_inductor/codegen/cpp_flex_attention_template.py
@@ -783,8 +783,9 @@ extern "C"
   const scalar_t* v_data = value;
   scalar_t* out_data = output;
 
+  auto actual_kvSize = kvSize / batchSize;
   auto num_partitions =
-      (kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
+      (actual_kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
 
   // Allocate temp buf (accumulate type)
   int64_t _accum_buff_size =
@@ -808,6 +809,7 @@ extern "C"
   auto tmp_out_strideH = num_partitions * headSize_v;
   auto tmp_out_strideS = headSize_v;
 
+  // TODO: For GQA, the parallelism dim of q_num_head can be changed to kv_num_head
   // Attention loop
   at::parallel_for(0, batchSize * num_head * num_partitions, 1, [&](int64_t begin, int64_t end) {
     int64_t i = 0, j = 0, partition_id = 0;
@@ -820,7 +822,8 @@ extern "C"
         : nullptr;
 
     for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
-      auto partition_start = partition_id * PARTITION_SIZE;
+      auto n_offset = i * actual_kvSize;
+      auto partition_start = n_offset + partition_id * PARTITION_SIZE;
       auto partition_end =
           std::min(partition_start + PARTITION_SIZE, kvSize);
 
@@ -930,9 +933,9 @@ extern "C"
 {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
           }
         }
+{%- endif %}
         token_num += cur_kvSplitSize;
       }
-{%- endif %}
 
 
       // 2) calculate the max and exp_sum for this partition
@@ -1044,13 +1047,11 @@ extern "C"
       auto global_exp_sum = 0.0;
 
       // Calculate the global max and exp_sum for this head
-      for (auto partition_id = 0; partition_id < num_partitions;
-           partition_id++) {
-        auto max_logit = max_logits_ptr
-            [i * max_logits_strideN +
-             j * max_logits_strideH + partition_id];
-        global_max = std::max(global_max, max_logit);
-      }
+      global_max = at::vec::reduce_all(
+          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
+          max_logits_ptr + i * max_logits_strideN +
+              j * max_logits_strideH,
+          num_partitions);
 
       // Update the partition 0 result with the global max
       auto partition0_out_start =
@@ -1058,7 +1059,7 @@ extern "C"
       auto max_logit0 = max_logits_ptr
          [i * max_logits_strideN + j * max_logits_strideH];
       float exp_val = std::exp(max_logit0 - global_max);
-      global_exp_sum +=
+      global_exp_sum =
          exp_sum_ptr[i * exp_sum_strideN + j * exp_sum_strideH] * exp_val;
 
       at::vec::map(
@@ -1400,6 +1401,29 @@ class CppFlexAttentionTemplate(CppTemplate):
     def apply_score_mod(self, score, b, h, q_idx, kv_idx):
         return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
 
+    def choose_flex_template(
+        self,
+        query: ir.Buffer,
+        num_threads,
+    ):
+        # choose from FLEX_ATTENTION or FLEX_DECODING
+        FLEX_TEMPLATE = FLEX_ATTENTION_TEMPLATE
+        q_batch_size = query.data.data.layout.size[0]
+        q_num_heads = query.data.data.layout.size[1]
+        q_seq_len = query.data.data.layout.size[2]
+        if all(
+            sympy.sympify(val).is_number
+            for val in [q_batch_size, q_num_heads, q_seq_len, num_threads]
+        ):
+            # if static shape
+            if (
+                self.partition_size % self.kv_block_size == 0
+                and q_seq_len == 1
+                and num_threads > q_batch_size * q_num_heads
+            ):
+                FLEX_TEMPLATE = FLEX_DECODING_TEMPLATE
+        return FLEX_TEMPLATE
+
     def render(  # type: ignore[override,return]
         self,
         kernel,
@@ -1464,19 +1488,10 @@ class CppFlexAttentionTemplate(CppTemplate):
             stack.enter_context(
                 patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
            )
-            if (
-                query.data.data.layout.size[2] == 1
-                and self.partition_size % self.kv_block_size == 0
-            ):
-                # use flash decoding when qSize == 1
-                return self._template_from_string(FLEX_DECODING_TEMPLATE).render(
-                    **options
-                )
-            else:
-                # use flash attention when qSize > 1
-                return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(
-                    **options
-                )
+            FLEX_TEMPLATE = self.choose_flex_template(query, num_threads)
+            return self._template_from_string(FLEX_TEMPLATE).render(
+                **options
+            )
 
     def codegen_softmax_fusion(self, kernel_name: str):
         # TODO: use inductor IR to rewrite those fusions
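
Note (illustration only, not part of the patch): the hunk at @@ -1400,6 +1401,29 @@
adds choose_flex_template(), which gates flash decoding on static shapes and on
the thread count. The sketch below restates that selection rule as a standalone
Python function so the condition can be read in isolation; the function name
choose_template_sketch, the string template names, and the explicit
partition_size / kv_block_size arguments are placeholders invented for this
example, while the real method reads the sizes from the Inductor query buffer
and returns the template constants.

    import sympy

    def choose_template_sketch(
        q_batch_size, q_num_heads, q_seq_len, num_threads,
        partition_size, kv_block_size,
    ):
        # Default to the general flash-attention template.
        template = "FLEX_ATTENTION"
        # Only consider flash decoding when every shape input is a static number.
        if all(
            sympy.sympify(v).is_number
            for v in (q_batch_size, q_num_heads, q_seq_len, num_threads)
        ):
            # Decoding path: single query token, partitions aligned to the KV
            # block size, and more threads than batch * heads work items.
            if (
                partition_size % kv_block_size == 0
                and q_seq_len == 1
                and num_threads > q_batch_size * q_num_heads
            ):
                template = "FLEX_DECODING"
        return template

    # A 4-batch, 8-head decode step (q_seq_len == 1) on 64 threads picks the
    # decoding template; the same shapes on 16 threads fall back to attention.
    assert choose_template_sketch(4, 8, 1, 64, 128, 64) == "FLEX_DECODING"
    assert choose_template_sketch(4, 8, 1, 16, 128, 64) == "FLEX_ATTENTION"

The thread-count condition is what the old qSize == 1 check in render() did not
have: splitting KV into partitions only pays off when there are more threads
available than batch * heads work items.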