[FlexAttention] Fix dynamic shaped heads flex_flash check (#165866)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165866
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #165729
committed by PyTorch MergeBot
parent 8951df03de
commit 6b80c94901
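The eligibility check previously assumed the score_mod subgraph always has exactly 5 leading placeholders; the call site now passes num_score_mod_placeholders=len(placeholder_inps), so the check also holds when the placeholder count differs (e.g. with dynamic shaped heads). Below is a minimal, self-contained sketch (not the PyTorch source) of the placeholder-counting logic introduced in this diff, run on a toy FX graph. Only input_buffers_require_grads and its num_score_mod_placeholders parameter mirror the diff; toy_score_mod, the scalar example inputs, and the use of ShapeProp to populate node.meta["tensor_meta"] are illustrative assumptions.

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
    """Return True if any placeholder beyond the score-mod ones requires grad."""
    inputs = [n for n in graph_module.graph.nodes if n.op == "placeholder"]
    if len(inputs) <= num_score_mod_placeholders:
        return False

    def requires_grad(n):
        tensor_meta = n.meta.get("tensor_meta")
        return tensor_meta.requires_grad if tensor_meta is not None else False

    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])


def toy_score_mod(score, b, h, q_idx, kv_idx, bias):
    # Five standard score_mod placeholders plus one captured buffer (bias).
    return score + bias


gm = symbolic_trace(toy_score_mod)
example_inputs = [torch.rand(()) for _ in range(5)] + [torch.rand((), requires_grad=True)]
ShapeProp(gm).propagate(*example_inputs)  # fills node.meta["tensor_meta"]

# With the placeholder count passed in (here 5), the captured `bias` buffer is
# detected as requiring grad, i.e. a flash-incompatible case is reported.
print(input_buffers_require_grads(gm, num_score_mod_placeholders=5))  # True

The actual change threads this count from the flex_attention lowering all the way down to the trivial-graph checks, as shown in the hunks below.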
@@ -193,7 +193,12 @@ def flex_attention(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
+    if _use_flex_flash_attention(
+        subgraph,
+        mask_graph,
+        kernel_options,
+        num_score_mod_placeholders=len(placeholder_inps),
+    ):
         return create_flex_flash_attention_kernel(
             query,
             key,
@@ -40,21 +40,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
 )
 
 
-def input_buffers_require_grads(graph_module):
-    """Check if any of the input buffers (beyond the first 5) require gradients."""
+def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
+    """Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
     inputs = []
     for node in graph_module.graph.nodes:
         if node.op == "placeholder":
             inputs.append(node)
-    if len(inputs) <= 5:
+    if len(inputs) <= num_score_mod_placeholders:
         return False
-    for n in inputs[5:]:
-        if n.meta["tensor_meta"].requires_grad:
-            return True
-    return False
+
+    def requires_grad(n):
+        tensor_meta = n.meta.get("tensor_meta")
+        return tensor_meta.requires_grad if tensor_meta is not None else False
+
+    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
 
 
-def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
+def is_trivial_graph(
+    graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
+):
     """Check if the flex graphs are compatible with Flash Attention."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
@@ -64,7 +68,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
     output_val = output[0].args[0]
 
     if is_score_graph:
-        if input_buffers_require_grads(graph_module):
+        if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
             return False
         return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
@@ -72,7 +76,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
 
 
 def _can_use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph
+    subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
 ) -> tuple[bool, str]:
     """Check if flex flash attention can be used for the given inputs.
 
@@ -82,14 +86,22 @@ def _can_use_flex_flash_attention(
     if not ensure_flash_available():
         return False, "CUTE flash attention library is not available"
 
-    if input_buffers_require_grads(subgraph.graph_module):
+    if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
         return (
             False,
             "Input buffers require gradients (not supported by flash attention)",
         )
 
-    score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
-    mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
+    score_trivial = is_trivial_graph(
+        subgraph.graph_module,
+        is_score_graph=True,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
+    mask_trivial = is_trivial_graph(
+        mask_graph.graph_module,
+        is_score_graph=False,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
 
     if not score_trivial and not mask_trivial:
         return (
@@ -111,12 +123,17 @@ def _can_use_flex_flash_attention(
 
 
 def _use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
+    subgraph: Subgraph,
+    mask_graph: Subgraph,
+    kernel_options: dict[str, Any],
+    num_score_mod_placeholders: int,
 ) -> bool:
     """Determine if we should use flex flash attention for the given inputs."""
     force_flash = kernel_options.get("force_flash", False)
 
-    can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
+    can_use, reason = _can_use_flex_flash_attention(
+        subgraph, mask_graph, num_score_mod_placeholders
+    )
 
     if force_flash and not can_use:
         raise RuntimeError(