Compare commits


4 Commits

Author SHA1 Message Date
887cadfb12 [inductor] template registry
# why

find templates by uid and enable passing information about which
template is used

# what

- expand registry for heuristics to also track templates
- register all templates on creation
- register extern kernels (pseudo templates) on creation

- this also now enforces that each uid is unique

# testing

existing tests
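
As a rough illustration of the scheme described above, here is a simplified, self-contained sketch (not the actual inductor code, which per the diff lives in `torch/_inductor/template_heuristics/registry.py`): templates self-register by uid at construction time, duplicate uids are rejected, and a uid can later be resolved back to the template instance.

```python
# Simplified, self-contained sketch of the uid-keyed registry described above;
# the real implementation also handles ExternKernelChoice ("pseudo template")
# instances and lives in torch/_inductor/template_heuristics/registry.py.
_TEMPLATE_REGISTRY: dict[str, "Template"] = {}

def register_template(template: "Template") -> None:
    existing = _TEMPLATE_REGISTRY.get(template.uid)
    if existing is not None and existing is not template:
        # uid uniqueness is enforced; re-registering the same instance is a no-op
        raise AssertionError(f"Duplicate template UID {template.uid!r}")
    _TEMPLATE_REGISTRY[template.uid] = template

def get_template_by_uid(uid: str) -> "Template":
    return _TEMPLATE_REGISTRY[uid]

class Template:
    def __init__(self, name: str) -> None:
        self.name = name
        register_template(self)  # templates register themselves on creation

    @property
    def uid(self) -> str:
        return self.name

mm_template = Template("mm")
assert get_template_by_uid("mm") is mm_template  # uid -> template lookup
```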

ghstack-source-id: 26767ccc1fe526d90d2dff78e7834e0fda7c5b67
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163726
2025-09-24 10:58:56 -07:00
c261c71f3e Simplify _compute_local_shape_and_global_offset and make it SPMD. (#163344)
There is only one substantive change: the branch on
`global_offset[shard_dim] <= local_offset[shard_dim]`
is removed because it is unnecessary: you can always treat the
first shard uniformly with the rest of the shards, because your
global offset is guaranteed to be zero in this case anyway.

I also switch the shard_size case to sym_ite, to make it possible
for LocalTensor to deal with the MPMD-ness here, but it's equivalent
to the old if-then-else.

I tried to rewrite the comments to make it clearer what is going on
algorithmically here.
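
For intuition, here is a tiny, evenly-split sketch (not the real DTensor code) of why the removed branch is unnecessary: the running global offset starts at zero, so the first shard on a dimension is handled by the same accumulation as every later shard.

```python
def local_shape_and_offset(global_len, mesh_shape, my_coord):
    # Successively shard one tensor dimension across several mesh dimensions,
    # assuming even splits for simplicity (uneven splits are the hard part
    # handled by the real implementation).
    local_len, global_offset = global_len, 0
    for mesh_dim_size, coord in zip(mesh_shape, my_coord):
        shard_len = local_len // mesh_dim_size
        shard_offset = coord * shard_len      # offset within the current chunk
        local_len = shard_len
        global_offset += shard_offset         # first shard adds onto a zero base
    return local_len, global_offset

# Length-16 dim sharded over a 2x2 mesh; the rank at coordinate (1, 1)
# owns 4 elements starting at global offset 8 + 4 = 12.
assert local_shape_and_offset(16, (2, 2), (1, 1)) == (4, 12)
```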

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163344
Approved by: https://github.com/albanD, https://github.com/zpcore, https://github.com/tianyu-l
2025-09-24 02:24:09 +00:00
e2ce79e4cc [Flex] Fix silent correctness w/ backpropping grads (#163677)
Fixes https://github.com/pytorch/pytorch/issues/162228

# Summary

The majority of our tests compile flex-attention only in isolation. This means that, for fake tensor propagation, the input primals and all captured buffers do no intermediate computation below autograd. As a result, they happen by chance to match the `requires_grad`-ness of the eager implementation, and this check passes. However, if score_mod is the result of some other intermediate fake tensor propagation, it is not guaranteed to have accurate `requires_grad`-ness, which is what was happening here.

TL;DR: this was a belt-and-suspenders check that was actually harmful; we should just let joint-graph tracing produce the correct joint graph.
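
As a rough sketch of the change (hypothetical helper name; the real logic is in the fake-tensor backward rule shown in the diff below): fake gradient buffers are now allocated for every tensor buffer, rather than only for buffers whose `requires_grad` flag happens to be set during fake tensor propagation.

```python
import torch

def fake_grads_for_buffers(score_mod_other_buffers):
    # Allocate a fake gradient for every tensor buffer; previously the condition
    # also required buf.requires_grad, which can be inaccurate when the buffer
    # is itself the product of intermediate fake-tensor computation.
    return [
        torch.empty_like(buf, memory_format=torch.contiguous_format)
        if isinstance(buf, torch.Tensor)
        else None
        for buf in score_mod_other_buffers
    ]

# A captured buffer built from other computation is not a leaf, and its
# requires_grad flag in the traced/fake path need not match eager.
y = torch.randn(2, 4, requires_grad=True)
bias_mat = y[:, :, None] + y[:, None, :]   # (2, 4, 4), non-leaf
grads = fake_grads_for_buffers([bias_mat, None])
```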

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163677
Approved by: https://github.com/ydwu4
2025-09-24 02:12:19 +00:00
be6c127927 [AOTI] Pass comments from metadata to the autotune block (#163600)
Summary: When generating Triton kernels in the compile-time autotune blocks, it is useful to include source information as code comments. Previously we ignored these comments for autotune code blocks because the generated main output code contains the same information, but that does not help when the generated autotune code crashes.
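
The mechanical piece of this, per the wrapper-codegen diff below, is rewriting C++-style comment markers in the kernel metadata into Python-style ones before splicing them into the autotune block; roughly:

```python
import re

# Illustrative metadata; the real comments carry source-node information.
metadata = "// Topologically Sorted Source Nodes: [mm]\n// Source node to ATen node mapping: ..."

# C++ "//" comment starters are replaced with Python "#" starters so the
# comments stay valid inside the generated Python autotune block.
python_metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE)
print(python_metadata)
```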

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163600
Approved by: https://github.com/yushangdi
2025-09-24 02:01:59 +00:00
11 changed files with 218 additions and 539 deletions

View File

@ -6636,6 +6636,35 @@ class TestLearnableBiases(InductorTestCase):
assert bias.grad, "No gradient computed for bias"
assert torch.any(bias.grad != 0), "Gradient for bias is 0"
@skip_on_cpu
def test_backprop_error_case(self, device):
@torch.compile()
def test(x, y):
# Materialize a bias matrix
B, L, device = x.shape[0], x.shape[1], x.device
b = torch.arange(B, device=device, dtype=torch.long).view(B, 1, 1)
q_idx = torch.arange(L, device=device, dtype=torch.long).view(1, L, 1)
kv_idx = torch.arange(L, device=device, dtype=torch.long).view(1, 1, L)
bias_mat = y[b, q_idx] + y[b, kv_idx] # (B, L, L)
# Dummy score_mod retrieving bias values
def score_mod(score, b, h, q_idx, kv_idx):
return score + bias_mat[b, q_idx, kv_idx]
x_ = x[:, :, None].repeat(1, 1, 16, 1)
# torch._dynamo.graph_break()
return flex_attention(x_, x_, x_, score_mod=score_mod)
B, L, D = 2, 16, 64
x = torch.randn(B, L, D, device=device, requires_grad=True)
y = torch.randn(B, L, device=device, requires_grad=True)
_ = test(x, y).mean().backward()
assert x.grad.norm() > 0
assert y.grad.norm() > 0
@skip_on_cpu
@common_utils.parametrize(
"params", get_params(device_configs["cuda"].dtypes), name_fn=lambda x: f"{x}"

View File

@ -27,7 +27,6 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, with_tf3
from torch.testing._internal.common_device_type import (
flex_attention_supported_platform as supported_platform,
instantiate_device_type_tests,
skipCUDAIf,
skipXPUIf,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS
@ -86,7 +85,7 @@ else:
LONG_COMPILATION_ON_CPU = True
test_dtypes = (
[torch.float32, torch.bfloat16, torch.float16]
[torch.float32, torch.bfloat16]
if torch.backends.mkldnn.is_available()
and torch.ops.mkldnn._is_mkldnn_bf16_supported()
else [torch.float32]
@ -578,7 +577,6 @@ class TestFlexDecoding(InductorTestCase):
v: Tensor,
dtype: torch.dtype = torch.float16,
block_mask: Optional[BlockMask] = None,
kernel_options: Optional[dict] = None,
device="cuda",
):
Q_B, Q_H, KV_H = q.shape[0], q.shape[1], k.shape[1]
@ -609,7 +607,6 @@ class TestFlexDecoding(InductorTestCase):
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
else:
compiled_lse = None
@ -621,7 +618,6 @@ class TestFlexDecoding(InductorTestCase):
block_mask=converted_block_mask,
score_mod=converted_score_mod,
enable_gqa=(not Q_H == KV_H),
kernel_options=kernel_options,
)
return compiled_out, compiled_lse
@ -638,7 +634,6 @@ class TestFlexDecoding(InductorTestCase):
KV_S: int = S,
V_D: int = D,
block_mask: Optional[BlockMask] = None,
kernel_options: Optional[dict] = None,
device="cuda",
):
assert Q_H % KV_H == 0
@ -675,14 +670,7 @@ class TestFlexDecoding(InductorTestCase):
ref_out, ref_lse = sdpa_partial(q_ref, k_ref, v_ref, return_lse=True)
compiled_out, compiled_lse = self.run_paged_attention(
score_mod,
q,
k,
v,
dtype,
block_mask,
device=device,
kernel_options=kernel_options,
score_mod, q, k, v, dtype, block_mask, device
)
self._check_out(
@ -749,7 +737,7 @@ class TestFlexDecoding(InductorTestCase):
block_mask = create_block_mask(noop_mask, Q_B, 1, Q_S, KV_S, device=device)
compiled_out, _ = self.run_paged_attention(
score_mod, q, k, v, dtype, block_mask, device=device
score_mod, q, k, v, dtype, block_mask, device
)
self._check_out(
@ -1582,23 +1570,6 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
):
flex_attention(query, key, value, _identity)
@supported_platform
@skipCUDAIf(True, "Not supported on CUDA")
@skipXPUIf(True, "Not supported on XPU")
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("partition_size", [64, 128, 256, 1024])
def test_flash_decoding_partition_size(self, device, dtype, partition_size):
def score_mod(score, b, h, m, n):
return score * 2
self.run_test_with_paged_attention(
score_mod,
dtype,
KV_S=64,
device=device,
kernel_options={"PARTITION_SIZE": partition_size},
)
@supported_platform
@patch.object(torch._inductor.config, "max_autotune", True)
def test_max_autotune(self, device):

View File

@ -1259,7 +1259,7 @@ def flex_attention_backward_fake_tensor_mode(
[
(
torch.empty_like(buffer, memory_format=torch.contiguous_format)
if isinstance(buffer, torch.Tensor) and buffer.requires_grad
if isinstance(buffer, torch.Tensor)
else None
)
for buffer in score_mod_other_buffers

View File

@ -2411,6 +2411,10 @@ class KernelTemplate:
def __init__(self, name: str, hash: Optional[str] = None) -> None:
self.name = name
self._hash = hash
# Register this template instance in the global registry
from ..template_heuristics.registry import register_template
register_template(self)
@property
def uid(self) -> str:

View File

@ -208,7 +208,7 @@ ALLOCATE_BUFFER = r"""
{{buffer_dtype}}* {{buffer_name}} = ({{buffer_dtype}}*){{buffer_name}}_data_ptr;
"""
INIT_PARAMS = r"""
FLEX_ATTENTION_TEMPLATE = r"""
{{template.header().getvalue()}}
#include <ATen/native/cpu/utils.h>
#include <ATen/native/CPUBlas.h>
@ -225,18 +225,16 @@ extern "C"
{{kernel.def_kernel(inputs=kernel_args, outputs={"output": output}, extra_sizevars=template.extra_sizevars)}}
{
{{ kernel.maybe_codegen_profile() }}
int64_t qBlockSize = {{qBlockSize}};
int64_t kvBlockSize = {{kvBlockSize}};
int64_t num_thread = {{num_thread}};
// dtypes
// dtypes of kernel and internal buffers
using scalar_t = {{kernel.dtype(query)}};
constexpr bool is_reduced_type = c10::is_reduced_floating_point_v<scalar_t>;
using accum_t = at::opmath_type<{{kernel.dtype(query)}}>;
using Vec = at::vec::Vectorized<accum_t>;
accum_t scaling_factor = {{scale}};
// sizes
int64_t qBlockSize = {{qBlockSize}};
int64_t kvBlockSize = {{kvBlockSize}};
int64_t num_thread = {{num_thread}};
int64_t batchSize = {{kernel.size(query, 0)}};
int64_t qSize = {{kernel.size(query, 1)}};
int64_t num_head = {{kernel.size(query, 2)}};
@ -257,18 +255,6 @@ extern "C"
int64_t gqa_shards_kvi = num_head / num_head_kvi;
int64_t bs_shards_kvi = batchSize / batchSize_kvi;
int64_t kvSize = {{kernel.size(key, 1)}};
int64_t qSplitSize = qBlockSize;
int64_t kvSplitSize = kvBlockSize;
qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
// Strides
int64_t kviStrideB = {{kernel.stride(kv_indices, 0)}};
int64_t kviStrideH = {{kernel.stride(kv_indices, 1)}};
int64_t kviStrideQ = {{kernel.stride(kv_indices, 2)}};
@ -290,6 +276,7 @@ extern "C"
auto kv_num_blocks_data = kv_num_blocks;
auto kv_indices_data = kv_indices;
// Strides
int64_t qStrideB = {{kernel.stride(query, 0)}};
int64_t qStrideM = {{kernel.stride(query, 1)}};
int64_t qStrideH = {{kernel.stride(query, 2)}};
@ -303,15 +290,18 @@ extern "C"
int64_t oStrideM = {{kernel.stride(output, 2)}};
int64_t oStrideH = {{kernel.stride(output, 1)}};
// Inputs/outputs buffers
const scalar_t* q_data = query;
const scalar_t* k_data = key;
const scalar_t* v_data = value;
scalar_t* out_data = output;
int64_t kvSize = {{kernel.size(key, 1)}};
"""
int64_t qSplitSize = qBlockSize;
int64_t kvSplitSize = kvBlockSize;
qSplitSize = qSplitSize > qSize ? qSize : qSplitSize;
kvSplitSize = kvSplitSize > kvSize ? kvSize : kvSplitSize;
int64_t qSlice = (qSize + qSplitSize - 1) / qSplitSize;
int64_t kvSlice = (kvSize + kvSplitSize - 1) / kvSplitSize;
int64_t kvTail = (kvSize - 1) % kvSplitSize + 1;
FLEX_ATTENTION_TEMPLATE = r"""
bool need_pack = false;
// Whether pack is needed for BFloat16/Half
if (is_reduced_type) {
@ -345,6 +335,12 @@ FLEX_ATTENTION_TEMPLATE = r"""
/* qk_sum */ qSplitSize +
/* dst */ qSplitSize * headSize_v;
// Inputs/outputs buffers
const scalar_t* q_data = query;
const scalar_t* k_data = key;
const scalar_t* v_data = value;
scalar_t* out_data = output;
// Buffers to store accum results, padding query and transpose/packing key/value
{{template.codegen_allocate_buffer("buf_data", "accum_t", "num_thread*_size_per_thread")}}
{{template.codegen_allocate_buffer("buf_reduced_data", "scalar_t", "num_thread*qSplitSize*ekvSplitSize")}}
@ -687,389 +683,8 @@ FLEX_ATTENTION_TEMPLATE = r"""
}
"""
FLEX_DECODING_TEMPLATE = r"""
int64_t PARTITION_SIZE = {{partition_size}};
// Check if score / mask mod dependent on batch_size / num_head
// Go into a fast path if independent
bool bs_head_independent_mod = true;
int64_t first_num_kvblocks = kv_num_blocks[0];
int64_t first_full_num_kvblocks = full_kv_num_blocks[0];
for (const auto& b : c10::irange(batchSize_kvi)) {
for (const auto& h : c10::irange(num_head_kvi)) {
if (*(kv_num_blocks + b * num_kviStrideB + h * num_kviStrideH) != first_num_kvblocks
|| *(full_kv_num_blocks + b * full_num_kviStrideB + h * full_num_kviStrideH) != first_full_num_kvblocks) {
bs_head_independent_mod = false;
break;
}
}
}
int64_t num_kvblocks_per_seq = kv_num_blocks[0] + full_kv_num_blocks[0];
int64_t num_kvblocks_per_partition = PARTITION_SIZE / kvBlockSize;
int64_t num_partitions = (num_kvblocks_per_seq + num_kvblocks_per_partition - 1) / num_kvblocks_per_partition;
if (!bs_head_independent_mod) {
num_partitions =
(kvSize + PARTITION_SIZE - 1) / PARTITION_SIZE;
}
// Allocate temp buf (accumulate type)
int64_t _accum_buff_size =
/* max_logits_ptr */ batchSize * num_head * num_partitions +
/* exp_sum_ptr */ batchSize * num_head * num_partitions +
/* tmp_out_ptr */ batchSize * num_head * num_partitions * headSize_v +
/* logits_ptrs */ num_thread * PARTITION_SIZE;
{{template.codegen_allocate_buffer("buf_data", "accum_t", "_accum_buff_size")}}
accum_t* max_logits_ptr = buf_data;
accum_t* exp_sum_ptr = max_logits_ptr + batchSize * num_head * num_partitions;
accum_t* tmp_out_ptr = exp_sum_ptr + batchSize * num_head * num_partitions;
accum_t* logits_ptrs = tmp_out_ptr + batchSize * num_head * num_partitions * headSize_v;
{{template.codegen_allocate_buffer("logits_reduced_ptrs", "scalar_t", "num_thread * PARTITION_SIZE")}}
auto max_logits_strideN = num_head * num_partitions;
auto max_logits_strideH = num_partitions;
auto exp_sum_strideN = num_head * num_partitions;
auto exp_sum_strideH = num_partitions;
auto tmp_out_strideN = num_head * num_partitions * headSize_v;
auto tmp_out_strideH = num_partitions * headSize_v;
auto tmp_out_strideS = headSize_v;
// Attention loop
at::parallel_for(0, batchSize * num_head * num_partitions, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0, partition_id = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head, partition_id, num_partitions);
int ompIdx = at::get_thread_num();
accum_t* logits = logits_ptrs + ompIdx * PARTITION_SIZE;
scalar_t* logits_reduced =
is_reduced_type
? logits_reduced_ptrs + ompIdx * PARTITION_SIZE
: nullptr;
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
auto kvblock_offset = num_kvblocks_per_partition * partition_id;
auto i_kvi = is_broadcast_bs_kvi ? i/bs_shards_kvi : i;
auto j_kvi = is_broadcast_head_kvi ? j/gqa_shards_kvi : j;
auto kv_logical_num_data = kv_num_blocks_data + i_kvi * num_kviStrideB +
j_kvi * num_kviStrideH;
int64_t kv_indice_num = *kv_logical_num_data;
std::vector<int64_t> kv_indice_list(kv_indice_num);
for(int64_t kv_i = 0; kv_i < kv_indice_num; kv_i++){
auto kv_logical_data = kv_indices_data + i_kvi * kviStrideB +
j_kvi * kviStrideH + kv_i;
kv_indice_list[kv_i] = *kv_logical_data;
}
{%- if has_full_kv_block %}
auto full_kv_logical_num_data = full_kv_num_blocks_data + i_kvi * num_kviStrideB +
j_kvi * num_kviStrideH;
int64_t full_kv_indice_num = *full_kv_logical_num_data;
std::vector<int64_t> full_kv_indice_list(full_kv_indice_num);
for(int64_t kv_i = 0; kv_i < full_kv_indice_num; kv_i++){
auto full_kv_logical_data = full_kv_indices_data + i_kvi * full_kviStrideB +
j_kvi * full_kviStrideH + kv_i;
full_kv_indice_list[kv_i] = *full_kv_logical_data;
}
{%- endif %}
int64_t cur_qSplitSize = 1;
auto i_kv = is_broadcast_bs_kv ? i/bs_shards : i;
auto j_kv = is_broadcast_head_kv ? j/gqa_shards : j;
accum_t* tmp_out = tmp_out_ptr + i * tmp_out_strideN +
j * tmp_out_strideH + partition_id * tmp_out_strideS;
// Initialize logits
{{kernel.kernel_name}}_fill_stub(logits,
static_cast<accum_t>(0), PARTITION_SIZE);
if (is_reduced_type) {
{{kernel.kernel_name}}_fill_stub(logits_reduced,
static_cast<scalar_t>(0), PARTITION_SIZE);
}
// 1) calculate the matmul(query, key) for this partition
int64_t token_num = 0;
{%- if has_full_kv_block %}
int64_t n_idx_start = kvblock_offset;
int64_t n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num + full_kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num + full_kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
{%- else %}
int64_t n_idx_start = kvblock_offset;
int64_t n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = kv_indice_list[n_idx]*kvSplitSize;
{%- endif %}
if (!bs_head_independent_mod
&& (n < partition_id * PARTITION_SIZE
|| n >= std::min(partition_id * PARTITION_SIZE + PARTITION_SIZE, kvSize))) {
continue;
}
auto cur_n = n/kvSplitSize;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto k_addr =
k_data + i_kv * kStrideB + j_kv * kStrideH + n * kStrideN;
{{kernel.kernel_name}}_kernel_micro_gemm_transpose_b<false>(
q_data + i * qStrideB + j * qStrideH,
k_addr,
logits + token_num,
cur_qSplitSize,
cur_kvSplitSize,
headSize,
qStrideM,
kStrideN,
cur_kvSplitSize);
{{kernel.kernel_name}}_mul_scale_kernel<accum_t>(logits + token_num, scaling_factor, cur_qSplitSize*cur_kvSplitSize);
{%- if score_mod and mask_mod %}
// TODO: reduce the number of calls of q_idx and kv_idx initialization
std::vector<int64_t> q_idx(cur_qSplitSize);
for (int64_t i = 0; i < cur_qSplitSize; ++i) {
q_idx[i] = i;
}
std::vector<int64_t> kv_idx(cur_kvSplitSize);
for (int64_t i = 0; i < cur_kvSplitSize; ++i) {
kv_idx[i] = n + i;
}
std::vector<int64_t> b_idx = {i};
std::vector<int64_t> h_idx = {j};
accum_t* in_ptr0 = logits + token_num;
const auto in_ptr1 = b_idx.data();
const auto in_ptr2 = h_idx.data();
const auto in_ptr3 = q_idx.data();
const auto in_ptr4 = kv_idx.data();
// apply score mod function
{
{{ template.generate_other_buffer("score_others", 0, "len_score_other", kernel.args) }}
accum_t* out_ptr{{score_buf_idx}} = in_ptr0;
{{ template.modification(score_mod, score_buf_name, score_buf_idx)|indent(12, false) }}
}
if ((std::find(kv_indice_list.begin(), kv_indice_list.end(), cur_n) != kv_indice_list.end()) ){
// Apply block mask, fill unused with -inf
{
{{ template.generate_other_buffer("mask_others", -1, "len_mask_other", kernel.args) }}
accum_t* out_ptr{{mask_buf_idx}} = in_ptr0;
{{ template.modification(mask_mod, mask_buf_name, mask_buf_idx)|indent(12, false) }}
}
}
{%- endif %}
token_num += cur_kvSplitSize;
}
// 2) calculate the max and exp_sum for this partition
auto partition_max = -std::numeric_limits<float>::infinity();
{{kernel.kernel_name}}_mul_reduce_max_fusion_kernel(
logits,
static_cast<accum_t>(1),
token_num,
logits,
partition_max);
if (partition_max == -std::numeric_limits<float>::infinity()) {
partition_max = 0;
}
max_logits_ptr[i * max_logits_strideN +
j * max_logits_strideH + partition_id] =
partition_max;
{{kernel.kernel_name}}_exp_reduce_sum_fusion_kernel(
logits,
token_num,
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced),
partition_max);
exp_sum_ptr[i * exp_sum_strideN +
j * exp_sum_strideH + partition_id] = partition_max;
// 3) calculate the matmul(exp(logits-partition_max), value) for this
// partition, need to divide the global exp_sum in the final result.
token_num = 0;
bool skipped_partition = true;
{%- if has_full_kv_block %}
n_idx_start = kvblock_offset;
n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num + full_kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num + full_kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = n_idx < kv_indice_num ? kv_indice_list[n_idx]*kvSplitSize : full_kv_indice_list[n_idx - kv_indice_num]*kvSplitSize;
{%- else %}
n_idx_start = kvblock_offset;
n_idx_end = std::min(kvblock_offset + num_kvblocks_per_partition, kv_indice_num);
if (!bs_head_independent_mod) {
n_idx_start = 0;
n_idx_end = kv_indice_num;
}
for (int64_t n_idx : c10::irange(n_idx_start, n_idx_end)) {
auto n = kv_indice_list[n_idx]*kvSplitSize;
{%- endif %}
if (!bs_head_independent_mod
&& (n < partition_id * PARTITION_SIZE
|| n >= std::min(partition_id * PARTITION_SIZE + PARTITION_SIZE, kvSize))) {
continue;
}
skipped_partition = false;
int64_t cur_kvSplitSize = std::min(kvSplitSize, kvSize - n);
auto v_addr =
v_data + i_kv * vStrideB + j_kv * vStrideH + n * vStrideN;
// Fallback Half brgemm is slower than micro gemm
if constexpr (!std::is_same_v<scalar_t, at::Half>) {
at::native::cpublas::brgemm(
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v,
token_num > 0,
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
false);
} else {
if (token_num > 0) {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(true)>(
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v);
} else {
{{kernel.kernel_name}}_kernel_micro_gemm<static_cast<bool>(false)>(
{{kernel.kernel_name}}_conditional_data_ptr(logits, logits_reduced) + token_num,
v_addr,
tmp_out,
cur_qSplitSize,
headSize_v,
cur_kvSplitSize,
cur_kvSplitSize,
vStrideN,
headSize_v);
}
}
token_num += cur_kvSplitSize;
}
if (skipped_partition) {
{{kernel.kernel_name}}_fill_stub(tmp_out,
static_cast<accum_t>(0), headSize_v);
}
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head, partition_id, num_partitions);
}
if constexpr (!std::is_same_v<scalar_t, at::Half>) {
at::native::cpublas::brgemm_release();
}
});
// Calculate the final output
at::parallel_for(0, batchSize * num_head, 1, [&](int64_t begin, int64_t end) {
int64_t i = 0, j = 0;
at::native::data_index_init(begin, i, batchSize, j, num_head);
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
auto global_max = -std::numeric_limits<float>::infinity();
auto global_exp_sum = 0.0;
// Calculate the global max and exp_sum for this head
global_max = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
max_logits_ptr + i * max_logits_strideN
+ j * max_logits_strideH,
num_partitions);
// Update the partition 0 result with the global max
auto partition0_out_start =
tmp_out_ptr + i * tmp_out_strideN + j * tmp_out_strideH;
auto max_logit0 = max_logits_ptr
[i * max_logits_strideN + j * max_logits_strideH];
float exp_val = std::exp(max_logit0 - global_max);
global_exp_sum =
exp_sum_ptr[i * exp_sum_strideN + j * exp_sum_strideH] *
exp_val;
at::vec::map<accum_t>(
[exp_val](Vec x) { return x * Vec(exp_val); },
partition0_out_start,
partition0_out_start,
headSize_v);
// Accumulate the partition 1 to partition n result into partition 0
if (num_partitions > 1) {
for (auto partition_id = 1; partition_id < num_partitions;
partition_id++) {
auto tmp_out_start = partition0_out_start + partition_id * tmp_out_strideS;
auto max_logit = max_logits_ptr
[i * max_logits_strideN + j * max_logits_strideH +
partition_id];
auto exp_sum = exp_sum_ptr
[i * exp_sum_strideN + j * exp_sum_strideH +
partition_id];
exp_val = std::exp(max_logit - global_max);
global_exp_sum += exp_sum * exp_val;
at::vec::map2<accum_t>(
[exp_val](Vec a, Vec b) { return a + Vec(exp_val) * b; },
partition0_out_start,
partition0_out_start,
tmp_out_start,
headSize_v);
}
}
// Rescale the partition 0 result with global exp_sum
// Sum for full masked out rows are 0, we set them to 1
// in order to avoid NaNs in the output and instead set fully
// masked out rows to 0
global_exp_sum = global_exp_sum == 0 ? 1 : global_exp_sum;
float sum_reciprocal = 1.0 / global_exp_sum;
// copy the partition 0 result into output
at::vec::map<scalar_t>(
[sum_reciprocal](Vec x) { return x * Vec(sum_reciprocal); },
out_data + i * oStrideB + j * oStrideH,
partition0_out_start,
headSize_v);
// Move to the next query
at::native::data_index_step(i, batchSize, j, num_head);
}
});
}
"""
class CppFlexAttentionTemplate(CppTemplate):
"""
CPP template based FlexAttention CPP Template.
This class supports generation of C++ code for broad attention variants,
with applying Flash Attention and Flash Decoding. It enables template-based
code synthesis according to user-defined score and mask modifications
and configuration settings.
"""
def __init__(
self,
input_nodes,
@ -1079,7 +694,6 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod,
kv_block_size,
q_block_size,
partition_size,
has_other_buffer,
no_full_kv_block,
fake_buffers,
@ -1111,7 +725,6 @@ class CppFlexAttentionTemplate(CppTemplate):
self.mask_buf_idx = get_idx(self.mask_buf_name) if self.mask_buf_name else None
self.kv_block_size = kv_block_size
self.q_block_size = q_block_size
self.partition_size = partition_size
self.has_other_buffer = has_other_buffer
self.no_full_kv_block = no_full_kv_block
self.other_buffer_input_offset = 2
@ -1308,7 +921,6 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod,
kv_block_size,
q_block_size,
partition_size,
has_other_buffer,
no_full_kv_block,
fake_buffers,
@ -1334,7 +946,6 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_mod=mask_mod,
kv_block_size=kv_block_size,
q_block_size=q_block_size,
partition_size=partition_size,
has_other_buffer=has_other_buffer,
no_full_kv_block=no_full_kv_block,
fake_buffers=fake_buffers,
@ -1349,32 +960,6 @@ class CppFlexAttentionTemplate(CppTemplate):
def apply_score_mod(self, score, b, h, q_idx, kv_idx):
return self.score_mod.graph_module(score, b, h, q_idx, kv_idx).item()
def choose_flex_template(
self,
query: ir.Buffer,
key: ir.Buffer,
num_threads,
):
# choose from FLEX_ATTENTION or FLEX_DECODING
FLEX_TEMPLATE = FLEX_ATTENTION_TEMPLATE
q_batch_size, q_num_heads, q_seq_len, _ = query.data.data.layout.size # type: ignore[attr-defined]
k_seq_len = key.data.data.layout.size[2] # type: ignore[attr-defined]
if all(
sympy.sympify(val).is_number
for val in [q_batch_size, q_num_heads, q_seq_len, k_seq_len, num_threads]
):
# if static shape, FLEX_DECODING will be chosen with these conditions:
# 1) partition size is multiple of kv block size, so each partition has several blocks
# 2) decoding scenario: q seq length is 1
# 3) The actual k seq length (k_seq_len / q_batch_size) is large enough
if (
self.partition_size % self.kv_block_size == 0
and q_seq_len == 1
and k_seq_len / q_batch_size >= max(self.partition_size * 2, 512)
):
FLEX_TEMPLATE = FLEX_DECODING_TEMPLATE
return FLEX_TEMPLATE
def render( # type: ignore[override,return]
self,
kernel,
@ -1432,17 +1017,13 @@ class CppFlexAttentionTemplate(CppTemplate):
mask_buf_name=self.mask_buf_name,
score_buf_idx=self.score_buf_idx,
mask_buf_idx=self.mask_buf_idx,
partition_size=self.partition_size,
)
with contextlib.ExitStack() as stack:
for buf in self.fake_buffers:
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
)
FLEX_TEMPLATE = self.choose_flex_template(query, key, num_threads)
return self._template_from_string(INIT_PARAMS + FLEX_TEMPLATE).render(
**options
)
return self._template_from_string(FLEX_ATTENTION_TEMPLATE).render(**options)
def codegen_softmax_fusion(self, kernel_name: str):
# TODO: use inductor IR to rewrite those fusions

View File

@ -2155,6 +2155,10 @@ class PythonWrapperCodegen(CodeGen):
def _format_kernel_definition(
kernel_name: str, kernel_body: str, metadata: Optional[str] = None
):
if config.triton.autotune_at_compile_time and metadata:
# Generating autotune block
# Need to replace C++ comment starter with Python comment starter
metadata = re.sub(r"^// ", "# ", metadata, flags=re.MULTILINE)
metadata_comment = f"{metadata}\n" if metadata else ""
body = f"\n\n{metadata_comment}{kernel_name} = {kernel_body}"
return body
@ -2168,9 +2172,8 @@ class PythonWrapperCodegen(CodeGen):
cpp_definition: Optional[str] = None,
):
if config.triton.autotune_at_compile_time:
# Skip inserting comments for the autotune block as they may contain cpp style comments
body = self._format_kernel_definition(
kernel_name, kernel_body, metadata=None
kernel_name, kernel_body, metadata=metadata
)
self.kernel_autotune_defs.splice(body)
if V.graph.cpp_wrapper:

View File

@ -289,8 +289,6 @@ def lower_cpu(
# Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_KV_BLOCK_SIZE)
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.guard_int(SPARSE_Q_BLOCK_SIZE)
# In flash decoding, the partition size of doing the parallelism on KV length dim
PARTITION_SIZE = kernel_options.get("PARTITION_SIZE", 128)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), (
@ -310,7 +308,6 @@ def lower_cpu(
mask_mod=None if skip_mask_score else mask_graph_buffer,
kv_block_size=SPARSE_KV_BLOCK_SIZE,
q_block_size=SPARSE_Q_BLOCK_SIZE,
partition_size=PARTITION_SIZE,
has_other_buffer=has_other_buffer,
no_full_kv_block=no_full_kv_block,
fake_buffers=fake_buffers,

View File

@ -2104,6 +2104,10 @@ class ExternKernelChoice:
# There is no src hash for ExternKernelChoice in the traditional sense
# so we indicate this by returning None
self.src_hash = None
# Register this template instance in the global registry
from .template_heuristics.registry import register_template
register_template(self)
def to_callable(self):
return getattr(extern_kernels, self.name)

View File

@ -18,6 +18,9 @@ from .base import TemplateConfigHeuristics
if TYPE_CHECKING:
from collections.abc import Iterator
from ..codegen.common import KernelTemplate
from ..select_algorithm import ExternKernelChoice
# Module-wide registry for template heuristics
_TEMPLATE_HEURISTIC_REGISTRY: dict[
@ -27,6 +30,9 @@ _TEMPLATE_HEURISTIC_REGISTRY: dict[
# Manual cache for successful lookups only (fallback instances are not cached)
_HEURISTIC_CACHE: dict[tuple[str, str, str], TemplateConfigHeuristics] = {}
# Template registry for serialization/deserialization of kernel template choices
_TEMPLATE_REGISTRY: dict[str, Union[KernelTemplate, ExternKernelChoice]] = {}
log = logging.getLogger(__name__)
@ -128,12 +134,68 @@ def get_template_heuristic(
def clear_registry() -> None:
"""
Clear all registered template heuristics.
Clear all registered template heuristics and templates.
This is primarily useful for testing purposes to ensure a clean state.
"""
_TEMPLATE_HEURISTIC_REGISTRY.clear()
_HEURISTIC_CACHE.clear()
_TEMPLATE_REGISTRY.clear()
def register_template(template: Union[KernelTemplate, ExternKernelChoice]) -> None:
"""
Register a template instance in the global template registry.
Args:
template: The template instance (KernelTemplate or ExternKernelChoice) to register
Raises:
AssertionError: If a template with the same UID is already registered
"""
template_uid = template.uid
if template_uid in _TEMPLATE_REGISTRY:
existing_template = _TEMPLATE_REGISTRY[template_uid]
if existing_template is not template:
raise AssertionError(
f"Duplicate template UID '{template_uid}' detected. "
f"Existing: {existing_template}, New: {template}"
)
# Same instance re-registering is OK
return
_TEMPLATE_REGISTRY[template_uid] = template
log.debug("Registered template with UID: %s", template_uid)
def get_template_by_uid(template_uid: str) -> Union[KernelTemplate, ExternKernelChoice]:
"""
Retrieve a template instance by its UID.
Args:
template_uid: The unique identifier of the template
Returns:
The template instance
Raises:
KeyError: If no template with the given UID is found
"""
if template_uid not in _TEMPLATE_REGISTRY:
raise KeyError(
f"Template with UID '{template_uid}' not found. "
f"Registered templates: {list(_TEMPLATE_REGISTRY.keys())}"
)
return _TEMPLATE_REGISTRY[template_uid]
def clear_template_registry() -> None:
"""
Clear all registered templates.
This is primarily useful for testing purposes to ensure a clean state.
"""
_TEMPLATE_REGISTRY.clear()
@contextlib.contextmanager

View File

@ -123,69 +123,92 @@ def _compute_local_shape_and_global_offset(
my_coordinate: Optional[list[int]],
placements: Sequence[Placement],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
ordered_placements = _explicit_order_placements(mesh_shape, placements)
"""
Suppose you have a full tensor with size global_shape, and you have sharded
it according to placements for mesh_shape. This function returns, for a
specific coordinate my_coordinate in the device mesh:
- The size of your local shard WITHOUT padding (i.e., if you have
an uneven split, your size might be smaller than the other entries
in your dim), and
- Where the data for your shard begins, in the full tensor.
This function is fairly simple if your tensor is evenly sharded; the complication
is around uneven splits. There is also some complication for handling StridedShard,
which changes the order you should apply sharding.
"""
if my_coordinate is None:
# if rank not in the mesh, return empty offset
return ((0,), ())
else:
local_shape = list(global_shape)
global_offset = [0] * len(global_shape)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh_shape[mesh_dim]
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_and_offset(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[mesh_dim],
)
local_shape[shard_dim] = shard_size
local_offset[shard_dim] = shard_offset
if shard_size == 0:
# Special case to fill in a standardized non-garbage value for the global_offset
# of zero-sized shards. This value is out of bounds of the tensor, so it won't conflict
# with any real offsets. DCP may rely on this value to de-duplicate shards.
global_offset[shard_dim] = global_shape[shard_dim]
else:
# On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim],
# it means that this dimension has been already sharded in previous placement.
# Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim].
# Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim].
if global_offset[shard_dim] <= local_offset[shard_dim]:
global_offset[shard_dim] = local_offset[shard_dim]
else:
global_offset[shard_dim] += local_offset[shard_dim]
# StridedShard implies a non-standard order to apply shards; get the
# correct order to start applying splits
ordered_placements = _explicit_order_placements(mesh_shape, placements)
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
# each local tensor will be further split into when sharding on higher mesh
# dimensions. However, this number is only correct if the DTensor is not
# sharded after the strided sharding completes. For example,
# [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# device mesh dim-2, and last on mesh dim-1. We define the
# "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# part because strided sharding happens on mesh dim-1 and it was caused by
# the fact that sharding on dim-2 occurred ahead. In this case, there's no
# further sharding after this strided sharding part and ``split_factor``
# correctly encodes the number. Another example is
# [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# dim-2. This violates our assumption that no further sharding shall occur
# after the strided sharding part and ``split_factor`` won't correctly
# encode the number of further split. So far, the only case where _StridedShard
# placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
return tuple(local_shape), tuple(global_offset)
local_shape = list(global_shape)
# We'll compute the data for where the shard beings on a per-dim basis.
# However, a single dim can be sharded multiple times, so we will end up
# doing a Sum(size*stride) like computation to determine the location of our
# shard for each of the shardings on that dim.
global_offset = [0] * len(global_shape)
for mesh_dim, placement in ordered_placements:
mesh_dim_size = mesh_shape[mesh_dim]
if isinstance(placement, Shard):
shard_dim = placement.dim
assert shard_dim < len(local_shape), (
f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
)
shard_size, shard_offset = placement._local_shard_size_and_offset(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[mesh_dim],
)
local_shape[shard_dim] = shard_size
global_offset[shard_dim] = torch.sym_ite(
shard_size == 0,
# Special case to fill in a standardized non-garbage value for
# the global_offset of zero-sized shards. This value is out
# of bounds of the tensor, so it won't conflict with any real
# offsets. DCP may rely on this value to de-duplicate shards.
# Note that you can end up with zero-size shards that are
# still otherwise in bounds for the tensor (TODO: give an
# example).
global_shape[shard_dim],
# As we successively shard the same dimension, we keep
# advancing our pointer beyond our original offset until we
# get to the final chunk start.
global_offset[shard_dim] + shard_offset,
)
# NOTE: the offset compute relies on the local shard index and it has no
# problem when strided sharding is not present. To correctly compute, we assume
# that the ``_StridedShard.split_factor`` field encodes how many partitions
# each local tensor will be further split into when sharding on higher mesh
# dimensions. However, this number is only correct if the DTensor is not
# sharded after the strided sharding completes. For example,
# [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements
# where the DTensor's dim-0 is first sharded on device mesh dim-0, then on
# device mesh dim-2, and last on mesh dim-1. We define the
# "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding
# part because strided sharding happens on mesh dim-1 and it was caused by
# the fact that sharding on dim-2 occurred ahead. In this case, there's no
# further sharding after this strided sharding part and ``split_factor``
# correctly encodes the number. Another example is
# [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's
# dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh
# dim-2. This violates our assumption that no further sharding shall occur
# after the strided sharding part and ``split_factor`` won't correctly
# encode the number of further split. So far, the only case where _StridedShard
# placement would appear is FSDP2 + TP on 2D mesh and the above case could only
# happen on mesh of 3 or more dimensions.
# TODO: change this function to correctly address this.
# TODO: this logic can be applied to contiguous sharding as well
return tuple(local_shape), tuple(global_offset)
def compute_global_tensor_info(

View File

@ -1970,8 +1970,13 @@ def get_all_device_types() -> list[str]:
return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
# skip since currently flex attention requires at least `avx2` support on CPU.
IS_FLEX_ATTENTION_CPU_PLATFORM_SUPPORTED = (
not torch.xpu.is_available() and not torch.cuda.is_available() and not IS_MACOS
not torch.xpu.is_available()
and not torch.cuda.is_available()
and not IS_MACOS
and torch.cpu._is_avx2_supported()
and os.getenv("ATEN_CPU_CAPABILITY") != "default"
)
IS_FLEX_ATTENTION_XPU_PLATFORM_SUPPORTED = (
torch.xpu.is_available() and torch.utils._triton.has_triton()