Compare commits

...

1 Commit

Author SHA1 Message Date
258b9329e9 [Inductor][Flex Attention] Add Subtiling Option for Alpha Scaling (#167356)
Summary:


Subtiles the alpha-scaling step before the second matmul (P @ V). This gives an average of ~17% performance improvement for sliding-window attention with a small head dim `D` (e.g. 64): https://docs.google.com/spreadsheets/d/18vFASeXtlymD2IwtrNw5o-z7TPM15VbReGiTeQE4_Jg/edit?gid=0#gid=0

Test Plan: `test/inductor/test_flex_attention.py`

Differential Revision: D86172337
2025-11-15 15:33:55 -08:00
5 changed files with 27 additions and 3 deletions
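For orientation, below is a rough single-head PyTorch sketch of the online-softmax recurrence that the flex-attention Triton template implements; the shapes, block size, and absence of masking are illustrative assumptions, not anything from this PR. `alpha` is the correction factor for the updated running row max, and `acc = acc * alpha[:, None]` is the alpha-scaling step that this change subtiles before the second matmul `p @ v`.

```python
# Reference sketch only (not the kernel): online-softmax attention over KV blocks.
import torch

def online_softmax_attention(q, k, v, block_n=128):
    # q: [M, D]; k, v: [N, D]; single head, no masking, float32 for clarity.
    M, D = q.shape
    scale = D ** -0.5
    acc = torch.zeros(M, v.shape[1])
    m_i = torch.full((M,), float("-inf"))
    l_i = torch.zeros(M)
    for start in range(0, k.shape[0], block_n):
        kb = k[start : start + block_n]
        vb = v[start : start + block_n]
        qk = (q @ kb.T) * scale
        m_ij = torch.maximum(m_i, qk.max(dim=1).values)   # new running row max
        p = torch.exp(qk - m_ij[:, None])
        alpha = torch.exp(m_i - m_ij)                     # rescale factor for old state
        l_i = l_i * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None]                        # <-- the alpha-scaling step
        acc = acc + p @ vb                                # second matmul (P @ V)
        m_i = m_ij
    return acc / l_i[:, None]

q = torch.randn(128, 64)
k = torch.randn(256, 64)
v = torch.randn(256, 64)
ref = torch.nn.functional.scaled_dot_product_attention(
    q[None, None], k[None, None], v[None, None]
)[0, 0]
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-4)
```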


@@ -968,7 +968,7 @@ def forward(self, x):
view_3 = torch.ops.aten.view.default(linear_3, [2, 1, 128, 64]); linear_3 = None
sdpa_score0 = self.sdpa_score0
sdpa_mask0 = self.sdpa_mask0
-flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
+flex_attention = torch.ops.higher_order.flex_attention(view_1, view_2, view_3, sdpa_score0, (128, 128, to_3, to_4, to_6, to_7, to_9, to_10, to_12, to_13, 128, 128, sdpa_mask0), 0.125, {'PRESCALE_QK': False, 'USE_SUBTILING': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': False, 'OUTPUT_MAX': False}, (), (detach,)); view_1 = view_2 = view_3 = sdpa_score0 = to_3 = to_4 = to_6 = to_7 = to_9 = to_10 = to_12 = to_13 = sdpa_mask0 = detach = None
getitem = flex_attention[0]
getitem_1 = flex_attention[1]; getitem_1 = None
getitem_2 = flex_attention[2]; flex_attention = getitem_2 = None


@@ -4154,7 +4154,7 @@ class GraphModule(torch.nn.Module):
score_mod_0 = self.score_mod_0
mask_fn_0 = self.mask_fn_0
-flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
+flex_attention = torch.ops.higher_order.flex_attention(l_query_, l_key_, l_value_, score_mod_0, (128, 128, l_block_mask_kv_num_blocks, l_block_mask_kv_indices, l_block_mask_full_kv_num_blocks, l_block_mask_full_kv_indices, l_block_mask_q_num_blocks, l_block_mask_q_indices, l_block_mask_full_q_num_blocks, l_block_mask_full_q_indices, 128, 128, mask_fn_0), 0.5, {'PRESCALE_QK': False, 'USE_SUBTILING': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); l_query_ = l_key_ = l_value_ = score_mod_0 = l_block_mask_kv_num_blocks = l_block_mask_kv_indices = l_block_mask_full_kv_num_blocks = l_block_mask_full_kv_indices = l_block_mask_q_num_blocks = l_block_mask_q_indices = l_block_mask_full_q_num_blocks = l_block_mask_full_q_indices = mask_fn_0 = None
out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
return (out,)
@@ -4194,7 +4194,7 @@ class GraphModule(torch.nn.Module):
fw_graph0 = self.fw_graph0
joint_graph0 = self.joint_graph0
mask_graph0 = self.mask_graph0
-flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
+flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (1, 1, full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'USE_SUBTILING': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True, 'OUTPUT_MAX': False}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
getitem_7: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None


@@ -56,6 +56,7 @@ serder
serdes
statics
strat
subtile
supercede
supercedes
te


@@ -100,7 +100,24 @@ def forward_block_mn(
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
{%- if USE_SUBTILING %}
acc0, acc1 = tl.split(
    acc.reshape(BLOCK_M, 2, V_HEAD_DIM_ROUNDED//2).permute((0, 2, 1))
)
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = (
    tl.join(acc0, acc1)
    .permute(0, 2, 1)
    .reshape(BLOCK_M, V_HEAD_DIM_ROUNDED)
)
{%- else %}
acc = acc * alpha[:, None]
{%- endif %}
{%- if USE_TMA %}
v = tl.load_tensor_descriptor(
desc_v,

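The reshape/permute/`tl.split` sequence added above views the accumulator as two half-head-dim tiles so each tile can be rescaled on its own (per the new option's docstring, to reduce register pressure), and `tl.join` plus the inverse permute/reshape restores the original layout. Below is a small PyTorch sketch checking that this subtiled rescale matches scaling the whole accumulator at once; indexing stands in for `tl.split`, `torch.stack` for `tl.join`, and the sizes are arbitrary stand-ins for the template's constexprs.

```python
# Equivalence sketch (assumed sizes): subtiled rescale == full rescale.
import torch

BLOCK_M, V_HEAD_DIM_ROUNDED = 128, 64
acc = torch.randn(BLOCK_M, V_HEAD_DIM_ROUNDED)
alpha = torch.rand(BLOCK_M)

# Mirror of the template: reshape to [BLOCK_M, 2, D//2], move the pair dim
# last, then peel off the two contiguous half-width column blocks.
tiles = acc.reshape(BLOCK_M, 2, V_HEAD_DIM_ROUNDED // 2).permute(0, 2, 1)
acc0, acc1 = tiles[..., 0], tiles[..., 1]
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
rejoined = (
    torch.stack((acc0, acc1), dim=-1)   # analogue of tl.join
    .permute(0, 2, 1)
    .reshape(BLOCK_M, V_HEAD_DIM_ROUNDED)
)

assert torch.allclose(rejoined, acc * alpha[:, None])
```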

@@ -205,6 +205,11 @@ class FlexKernelOptions(TypedDict, total=False):
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""
# pyrefly: ignore [invalid-annotation]
USE_SUBTILING: NotRequired[bool]
"""Whether to subtile the alpha-scaling step before the second matmul (P @ V)
to reduce register pressure for large head dims. Default: False."""
# ROCm-specific options
# pyrefly: ignore [invalid-annotation]
kpack: NotRequired[int]
@@ -1243,6 +1248,7 @@ def _apply_kernel_options(
kernel_options = {} if kernel_options is None else dict(kernel_options)
kernel_options.setdefault("PRESCALE_QK", False)
kernel_options.setdefault("USE_SUBTILING", False)
kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False)
kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False)
# This forces all biases grad scatters to be done in the DQ iteration loop of the backwards
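With the default registered here, the option can be turned on per call through `kernel_options`. A minimal usage sketch follows, assuming a CUDA device; the shapes, the 256-token sliding-window mask, and the `torch.compile` wrapper are illustrative choices rather than part of this PR (head dim 64 simply matches the small-`D` regime the commit message benchmarks).

```python
# Hypothetical usage: enable the new subtiling path via kernel_options.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def sliding_window(b, h, q_idx, kv_idx):
    # Causal sliding-window mask: attend to at most the previous 256 tokens.
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= 256)

B, H, S, D = 2, 8, 2048, 64
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)
block_mask = create_block_mask(sliding_window, B, H, S, S, device="cuda")

flex_compiled = torch.compile(flex_attention)
out = flex_compiled(
    q, k, v, block_mask=block_mask, kernel_options={"USE_SUBTILING": True}
)
```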