diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 1081afc25520..1e631b4af389 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -1970,6 +1970,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(score_mod_scale, dtype, device=device)
 
+    @supported_platform
+    @skip_on_cpu
+    @dtypes(torch.float16)
+    @dtypesIfCUDA(torch.float16)
+    def test_dynamic_captured_buffer(self, device, dtype):
+        def run_with_head_count(compiled_fa, head_count):
+            head_scale = torch.randn(
+                head_count, device=device, dtype=dtype, requires_grad=True
+            )
+
+            def score_mod(score, batch, head, token_q, token_kv):
+                return score * head_scale[head]
+
+            q = torch.randn(
+                B, head_count, S, D, device=device, dtype=dtype, requires_grad=True
+            )
+            k = torch.randn_like(q, requires_grad=True)
+            v = torch.randn_like(q, requires_grad=True)
+
+            block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device)
+
+            out = compiled_fa(q, k, v, score_mod=score_mod, block_mask=block_mask)
+            loss = out.sum()
+            loss.backward()
+            return out
+
+        compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+
+        head_counts = [4, 8, 4, 16, 4]
+        for head_count in head_counts:
+            run_with_head_count(compiled_fa, head_count)
+
     @supported_platform
     @dtypes(*device_configs["cpu"].dtypes_fast)
     @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py
index e692b3237121..f62eb70d967e 100644
--- a/torch/_inductor/kernel/flex/flex_attention.py
+++ b/torch/_inductor/kernel/flex/flex_attention.py
@@ -193,7 +193,12 @@ def flex_attention(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
+    if _use_flex_flash_attention(
+        subgraph,
+        mask_graph,
+        kernel_options,
+        num_score_mod_placeholders=len(placeholder_inps),
+    ):
         return create_flex_flash_attention_kernel(
             query,
             key,
diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py
index 5fedcedf6488..e613354a8925 100644
--- a/torch/_inductor/kernel/flex/flex_flash_attention.py
+++ b/torch/_inductor/kernel/flex/flex_flash_attention.py
@@ -41,21 +41,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
 )
 
 
-def input_buffers_require_grads(graph_module):
-    """Check if any of the input buffers (beyond the first 5) require gradients."""
+def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
+    """Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
     inputs = []
     for node in graph_module.graph.nodes:
         if node.op == "placeholder":
             inputs.append(node)
-    if len(inputs) <= 5:
+    if len(inputs) <= num_score_mod_placeholders:
         return False
-    for n in inputs[5:]:
-        if n.meta["tensor_meta"].requires_grad:
-            return True
-    return False
+
+    def requires_grad(n):
+        tensor_meta = n.meta.get("tensor_meta")
+        return tensor_meta.requires_grad if tensor_meta is not None else False
+
+    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
 
 
-def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
+def is_trivial_graph(
+    graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
+):
     """Check if the flex graphs are compatible with Flash Attention."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
@@ -65,7 +69,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
     output_val = output[0].args[0]
     if is_score_graph:
-        if input_buffers_require_grads(graph_module):
+        if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
             return False
         return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
@@ -73,7 +77,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
 
 
 def _can_use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph
+    subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
 ) -> tuple[bool, str]:
     """Check if flex flash attention can be used for the given inputs.
 
@@ -83,14 +87,22 @@ def _can_use_flex_flash_attention(
     if not ensure_flash_available():
         return False, "CUTE flash attention library is not available"
 
-    if input_buffers_require_grads(subgraph.graph_module):
+    if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
         return (
             False,
             "Input buffers require gradients (not supported by flash attention)",
         )
 
-    score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
-    mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
+    score_trivial = is_trivial_graph(
+        subgraph.graph_module,
+        is_score_graph=True,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
+    mask_trivial = is_trivial_graph(
+        mask_graph.graph_module,
+        is_score_graph=False,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
 
     if not score_trivial and not mask_trivial:
         return (
@@ -112,12 +124,17 @@ def _can_use_flex_flash_attention(
 
 
 def _use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
+    subgraph: Subgraph,
+    mask_graph: Subgraph,
+    kernel_options: dict[str, Any],
+    num_score_mod_placeholders: int,
 ) -> bool:
     """Determine if we should use flex flash attention for the given inputs."""
     force_flash = kernel_options.get("force_flash", False)
 
-    can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
+    can_use, reason = _can_use_flex_flash_attention(
+        subgraph, mask_graph, num_score_mod_placeholders
+    )
 
     if force_flash and not can_use:
         raise RuntimeError(
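
Note (not part of the patch): the sketch below is an illustrative standalone example of the idea the diff implements. It assumes the usual score_mod calling convention of five arguments (score, b, h, q_idx, kv_idx) plus any captured buffers lifted to explicit graph inputs, and it shows why the caller should pass the actual placeholder count rather than hard-coding 5 as the old input_buffers_require_grads did. The names score_mod_like and captured_buffers_require_grad are hypothetical stand-ins, not Inductor APIs; the example uses plain torch.fx rather than Inductor's lifted subgraphs.

    # Illustrative sketch only: counting score_mod placeholders explicitly.
    import torch
    from torch.fx import symbolic_trace


    def score_mod_like(score, b, h, q_idx, kv_idx, head_scale):
        # head_scale stands in for a captured buffer lifted to an explicit input.
        return score * head_scale[h]


    gm = symbolic_trace(score_mod_like)
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    print(len(placeholders))  # 6: the five score_mod args plus one captured buffer


    def captured_buffers_require_grad(graph_module, num_score_mod_placeholders, buffers):
        # Hypothetical helper mirroring input_buffers_require_grads: only inputs
        # *after* the score_mod placeholders (the captured buffers) are inspected.
        # The real check reads node.meta["tensor_meta"]; here we consult the example
        # tensors directly because symbolic_trace does not populate tensor_meta.
        inputs = [n for n in graph_module.graph.nodes if n.op == "placeholder"]
        extra = inputs[num_score_mod_placeholders:]
        return any(t.requires_grad for t in buffers[: len(extra)])


    head_scale = torch.randn(8, requires_grad=True)
    print(captured_buffers_require_grad(gm, num_score_mod_placeholders=5, buffers=[head_scale]))
    # True -> the flash path is rejected, matching the patch's "Input buffers require
    # gradients (not supported by flash attention)" reason string.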