[FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
This commit is contained in:
drisspg
2025-10-19 17:58:33 +00:00
committed by PyTorch MergeBot
parent 8951df03de
commit 6b80c94901
3 changed files with 70 additions and 16 deletions

View File

@ -193,7 +193,12 @@ def flex_attention(
score_mod_other_buffers,
mask_mod_other_buffers,
)
if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
if _use_flex_flash_attention(
subgraph,
mask_graph,
kernel_options,
num_score_mod_placeholders=len(placeholder_inps),
):
return create_flex_flash_attention_kernel(
query,
key,

View File

@ -40,21 +40,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
)
def input_buffers_require_grads(graph_module):
"""Check if any of the input buffers (beyond the first 5) require gradients."""
def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
"""Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
inputs = []
for node in graph_module.graph.nodes:
if node.op == "placeholder":
inputs.append(node)
if len(inputs) <= 5:
if len(inputs) <= num_score_mod_placeholders:
return False
for n in inputs[5:]:
if n.meta["tensor_meta"].requires_grad:
return True
return False
def requires_grad(n):
tensor_meta = n.meta.get("tensor_meta")
return tensor_meta.requires_grad if tensor_meta is not None else False
return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
def is_trivial_graph(
graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
):
"""Check if the flex graphs are compatible with Flash Attention."""
graph = graph_module.graph
nodes = list(graph.nodes)
@ -64,7 +68,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
output_val = output[0].args[0]
if is_score_graph:
if input_buffers_require_grads(graph_module):
if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
return False
return True # party on garth
# mask mod graph is empty if we have 4 inputs and full_default output
@ -72,7 +76,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
def _can_use_flex_flash_attention(
subgraph: Subgraph, mask_graph: Subgraph
subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
) -> tuple[bool, str]:
"""Check if flex flash attention can be used for the given inputs.
@ -82,14 +86,22 @@ def _can_use_flex_flash_attention(
if not ensure_flash_available():
return False, "CUTE flash attention library is not available"
if input_buffers_require_grads(subgraph.graph_module):
if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
return (
False,
"Input buffers require gradients (not supported by flash attention)",
)
score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
score_trivial = is_trivial_graph(
subgraph.graph_module,
is_score_graph=True,
num_score_mod_placeholders=num_score_mod_placeholders,
)
mask_trivial = is_trivial_graph(
mask_graph.graph_module,
is_score_graph=False,
num_score_mod_placeholders=num_score_mod_placeholders,
)
if not score_trivial and not mask_trivial:
return (
@ -111,12 +123,17 @@ def _can_use_flex_flash_attention(
def _use_flex_flash_attention(
subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
subgraph: Subgraph,
mask_graph: Subgraph,
kernel_options: dict[str, Any],
num_score_mod_placeholders: int,
) -> bool:
"""Determine if we should use flex flash attention for the given inputs."""
force_flash = kernel_options.get("force_flash", False)
can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
can_use, reason = _can_use_flex_flash_attention(
subgraph, mask_graph, num_score_mod_placeholders
)
if force_flash and not can_use:
raise RuntimeError(