mirror of https://github.com/pytorch/pytorch.git
update
@@ -783,8 +783,9 @@ extern "C"
   const scalar_t* v_data = value;
   scalar_t* out_data = output;

+  auto actual_kvSize = kvSize / batchSize;
   auto num_partitions =
-      (kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
+      (actual_kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;

   // Allocate temp buf (accumulate type)
   int64_t _accum_buff_size =
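The partition count above is a plain ceiling division, now taken over the per-batch KV length rather than the batch-flattened total. A minimal Python sketch of that arithmetic, assuming (as the diff implies) that kvSize is flattened across the batch; the PARTITION_SIZE value is illustrative, not the template's constant:

    # Hypothetical partition-size constant, chosen only for the example.
    PARTITION_SIZE = 128

    def num_partitions(kv_size: int, batch_size: int) -> int:
        # kv_size is assumed to be flattened across the batch,
        # so each batch entry owns kv_size // batch_size tokens.
        actual_kv_size = kv_size // batch_size
        return (actual_kv_size + PARTITION_SIZE - 1) // PARTITION_SIZE  # ceil-div

    # batchSize=4 with 1000 KV tokens each: 8 partitions per (batch, head),
    # where partitioning the flattened 4000 tokens gave 32.
    print(num_partitions(4000, 4))  # 8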
@@ -808,6 +809,7 @@ extern "C"
   auto tmp_out_strideH = num_partitions * headSize_v;
   auto tmp_out_strideS = headSize_v;

+  // TODO: For GQA, the parallelism dim of q_num_head can be changed to kv_num_head
   // Attention loop
   at::parallel_for(0, batchSize * num_head * num_partitions, 1, [&](int64_t begin, int64_t end) {
     int64_t i = 0, j = 0, partition_id = 0;
@@ -820,7 +822,8 @@ extern "C"
         : nullptr;

     for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
-      auto partition_start = partition_id * PARTITION_SIZE;
+      auto n_offset = i * actual_kvSize;
+      auto partition_start = n_offset + partition_id * PARTITION_SIZE;
       auto partition_end =
           std::min(partition_start + PARTITION_SIZE, kvSize);

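For context, a hedged Python sketch of how the flat parallel_for index presumably decomposes into (batch i, head j, partition_id), and how the new n_offset places each partition inside its batch's slice of the flattened KV buffer. The names and loop order are illustrative, and the clamp mirrors the diff's std::min against kvSize:

    def partition_bounds(flat_idx, num_head, num_partitions,
                         actual_kv_size, kv_size, partition_size):
        # Decompose the flat work index the way a row-major
        # (batch, head, partition) loop nest would.
        i, rem = divmod(flat_idx, num_head * num_partitions)   # batch index
        j, partition_id = divmod(rem, num_partitions)          # head, partition
        n_offset = i * actual_kv_size                          # start of batch i's KV slice
        partition_start = n_offset + partition_id * partition_size
        partition_end = min(partition_start + partition_size, kv_size)
        return i, j, partition_start, partition_end

    # batch 1, head 0, partition 1 of a 4 x 1000 flattened KV cache:
    print(partition_bounds(17, 2, 8, 1000, 4000, 128))  # (1, 0, 1128, 1256)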
@@ -930,9 +933,9 @@ extern "C"
             {{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
           }
         }
         {%- endif %}
         token_num += cur_kvSplitSize;
       }
       {%- endif %}


       // 2) calculate the max and exp_sum for this partition
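A minimal sketch of step 2 named in the comment above, for one partition in isolation: reduce the partition's logits to their max and to a softmax denominator expressed relative to that max. Plain Python, illustrative only:

    import math

    def partition_max_and_exp_sum(logits):
        # Per-partition pieces that the merge step later combines across partitions.
        local_max = max(logits)
        exp_sum = sum(math.exp(x - local_max) for x in logits)
        return local_max, exp_sum

    print(partition_max_and_exp_sum([0.5, 2.0, -1.0]))  # (2.0, 1.2729...)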
@@ -1044,13 +1047,11 @@ extern "C"
       auto global_exp_sum = 0.0;

       // Calculate the global max and exp_sum for this head
-      for (auto partition_id = 0; partition_id < num_partitions;
-           partition_id++) {
-        auto max_logit = max_logits_ptr
-            [i * max_logits_strideN +
-             j * max_logits_strideH + partition_id];
-        global_max = std::max(global_max, max_logit);
-      }
+      global_max = at::vec::reduce_all<float>(
+          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
+          max_logits_ptr + i * max_logits_strideN
+              + j * max_logits_strideH,
+          num_partitions);

       // Update the partition 0 result with the global max
       auto partition0_out_start =
@@ -1058,7 +1059,7 @@ extern "C"
       auto max_logit0 = max_logits_ptr
          [i * max_logits_strideN + j * max_logits_strideH];
       float exp_val = std::exp(max_logit0 - global_max);
-      global_exp_sum +=
+      global_exp_sum =
          exp_sum_ptr[i * exp_sum_strideN + j * exp_sum_strideH] *
          exp_val;
       at::vec::map<accum_t>(
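Taken together, the two hunks above perform the standard flash-decoding merge for one (batch, head) pair: take the maximum over the partitions' max logits, then rescale each partition's exp-sum into that global frame before accumulating. A reference-level Python sketch of the same math, purely illustrative of what the kernel does with at::vec primitives:

    import math

    def merge_partitions(max_logits, exp_sums):
        # max_logits[p], exp_sums[p]: per-partition results from step 2.
        global_max = max(max_logits)
        # Each partition's denominator was computed relative to its own max,
        # so rescale by exp(local_max - global_max) before summing.
        global_exp_sum = sum(s * math.exp(m - global_max)
                             for m, s in zip(max_logits, exp_sums))
        return global_max, global_exp_sum

    print(merge_partitions([2.0, 3.5], [1.8, 2.1]))  # (3.5, ~2.50)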
@@ -1400,6 +1401,29 @@ class CppFlexAttentionTemplate(CppTemplate):
     def apply_score_mod(self, score, b, h, q_idx, kv_idx):
         return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()

+    def choose_flex_template(
+        self,
+        query: ir.Buffer,
+        num_threads,
+    ):
+        # choose from FLEX_ATTENTION or FLEX_DECODING
+        FLEX_TEMPLATE = FLEX_ATTENTION_TEMPLATE
+        q_batch_size = query.data.data.layout.size[0]
+        q_num_heads = query.data.data.layout.size[1]
+        q_seq_len = query.data.data.layout.size[2]
+        if all(
+            sympy.sympify(val).is_number
+            for val in [q_batch_size, q_num_heads, q_seq_len, num_threads]
+        ):
+            # if static shape
+            if (
+                self.partition_size % self.kv_block_size == 0
+                and q_seq_len == 1
+                and num_threads > q_batch_size * q_num_heads
+            ):
+                FLEX_TEMPLATE = FLEX_DECODING_TEMPLATE
+        return FLEX_TEMPLATE
+
     def render(  # type: ignore[override,return]
         self,
         kernel,
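A hypothetical, simplified restatement of the heuristic choose_flex_template introduces, with plain integers in place of the inductor layout and sympy plumbing: the decoding template is chosen only when the query is decode-shaped (q_seq_len == 1), the partition size divides evenly by the KV block size, and there are more threads than (batch x heads) work items to keep busy:

    def pick_template(q_batch, q_heads, q_seq_len, num_threads,
                      partition_size, kv_block_size):
        # Mirrors the static-shape branch above; names are illustrative.
        use_decoding = (
            partition_size % kv_block_size == 0
            and q_seq_len == 1
            and num_threads > q_batch * q_heads
        )
        return "FLEX_DECODING" if use_decoding else "FLEX_ATTENTION"

    print(pick_template(1, 8, 1, 32, 128, 64))   # FLEX_DECODING: 32 threads > 8 work items
    print(pick_template(8, 32, 1, 32, 128, 64))  # FLEX_ATTENTION: 256 items already fill 32 threads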
@@ -1464,19 +1488,10 @@ class CppFlexAttentionTemplate(CppTemplate):
         stack.enter_context(
             patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
         )
-        if (
-            query.data.data.layout.size[2] == 1
-            and self.partition_size % self.kv_block_size == 0
-        ):
-            # use flash decoding when qSize == 1
-            return self._template_from_string(FLEX_DECODING_TEMPLATE).render(
-                **options
-            )
-        else:
-            # use flash attention when qSize > 1
-            return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(
-                **options
-            )
+        FLEX_TEMPLATE = self.choose_flex_template(query, num_threads)
+        return self._template_from_string(FLEX_TEMPLATE).render(
+            **options
+        )

     def codegen_softmax_fusion(self, kernel_name: str):
         # TODO: use inductor IR to rewrite those fusions