[FlexAttention] Fix dynamic shaped heads flex_flash check

ghstack-source-id: 9b9ede68b091ae3bf97433c8210321638a5dcbcf
Pull-Request: https://github.com/pytorch/pytorch/pull/165866
This commit is contained in:
drisspg
2025-10-19 17:58:33 +00:00
parent 3734178a49
commit a317caf67e
3 changed files with 70 additions and 16 deletions

View File

@ -1970,6 +1970,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
self.run_test(score_mod_scale, dtype, device=device)
@supported_platform
@skip_on_cpu
@dtypes(torch.float16)
@dtypesIfCUDA(torch.float16)
def test_dynamic_captured_buffer(self, device, dtype):
def run_with_head_count(compiled_fa, head_count):
head_scale = torch.randn(
head_count, device=device, dtype=dtype, requires_grad=True
)
def score_mod(score, batch, head, token_q, token_kv):
return score * head_scale[head]
q = torch.randn(
B, head_count, S, D, device=device, dtype=dtype, requires_grad=True
)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device)
out = compiled_fa(q, k, v, score_mod=score_mod, block_mask=block_mask)
loss = out.sum()
loss.backward()
return out
compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True)
head_counts = [4, 8, 4, 16, 4]
for head_count in head_counts:
run_with_head_count(compiled_fa, head_count)
@supported_platform
@dtypes(*device_configs["cpu"].dtypes_fast)
@dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)

View File

@ -193,7 +193,12 @@ def flex_attention(
score_mod_other_buffers,
mask_mod_other_buffers,
)
if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
if _use_flex_flash_attention(
subgraph,
mask_graph,
kernel_options,
num_score_mod_placeholders=len(placeholder_inps),
):
return create_flex_flash_attention_kernel(
query,
key,

View File

@ -41,21 +41,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
)
def input_buffers_require_grads(graph_module):
"""Check if any of the input buffers (beyond the first 5) require gradients."""
def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
"""Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
inputs = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
inputs.append(node)
if len(inputs) <= 5:
if len(inputs) <= num_score_mod_placeholders:
return False
for n in inputs[5:]:
if n.meta["tensor_meta"].requires_grad:
return True
return False
def requires_grad(n):
tensor_meta = n.meta.get("tensor_meta")
return tensor_meta.requires_grad if tensor_meta is not None else False
return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
def is_trivial_graph(
graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
):
"""Check if the flex graphs are compatible with Flash Attention."""
graph = graph_module.graph
nodes = list(graph.nodes)
@ -65,7 +69,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
output_val = output[0].args[0]
if is_score_graph:
if input_buffers_require_grads(graph_module):
if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
return False
return True # party on garth
# mask mod graph is empty if we have 4 inputs and full_default output
@ -73,7 +77,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
def _can_use_flex_flash_attention(
subgraph: Subgraph, mask_graph: Subgraph
subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
) -> tuple[bool, str]:
"""Check if flex flash attention can be used for the given inputs.
@ -83,14 +87,22 @@ def _can_use_flex_flash_attention(
if not ensure_flash_available():
return False, "CUTE flash attention library is not available"
if input_buffers_require_grads(subgraph.graph_module):
if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
return (
False,
"Input buffers require gradients (not supported by flash attention)",
)
score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
score_trivial = is_trivial_graph(
subgraph.graph_module,
is_score_graph=True,
num_score_mod_placeholders=num_score_mod_placeholders,
)
mask_trivial = is_trivial_graph(
mask_graph.graph_module,
is_score_graph=False,
num_score_mod_placeholders=num_score_mod_placeholders,
)
if not score_trivial and not mask_trivial:
return (
@ -112,12 +124,17 @@ def _can_use_flex_flash_attention(
def _use_flex_flash_attention(
subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
subgraph: Subgraph,
mask_graph: Subgraph,
kernel_options: dict[str, Any],
num_score_mod_placeholders: int,
) -> bool:
"""Determine if we should use flex flash attention for the given inputs."""
force_flash = kernel_options.get("force_flash", False)
can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
can_use, reason = _can_use_flex_flash_attention(
subgraph, mask_graph, num_score_mod_placeholders
)
if force_flash and not can_use:
raise RuntimeError(