Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[FlexAttention] Fix dynamic shaped heads flex_flash check
ghstack-source-id: 9b9ede68b091ae3bf97433c8210321638a5dcbcf
Pull-Request: https://github.com/pytorch/pytorch/pull/165866
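The gist, per the hunks below: the CuteDSL flex-flash eligibility helpers hard-coded five score_mod placeholders (slicing captured buffers as `inputs[5:]`), which can go wrong once dynamic shapes are involved, so the lowering now threads `num_score_mod_placeholders=len(placeholder_inps)` through every check. A minimal sketch of the user-level scenario, distilled from the new test added below; the sizes, device, and dtype here are illustrative assumptions, not taken from the PR:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, S, D = 2, 128, 64          # illustrative sizes (assumption)
device, dtype = "cuda", torch.float16

compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True)

def run(head_count):
    # Captured buffer whose length follows the (dynamic) head dimension.
    head_scale = torch.randn(head_count, device=device, dtype=dtype)

    def score_mod(score, batch, head, q_idx, kv_idx):
        return score * head_scale[head]

    q = torch.randn(B, head_count, S, D, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    return compiled_fa(q, k, v, score_mod=score_mod)

# Varying the head count across calls exercises the dynamic-shape path that the
# parameterized check now handles; the full test (with grads and a block mask)
# appears in the first hunk below.
for head_count in (4, 8, 4, 16, 4):
    run(head_count)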
@@ -1970,6 +1970,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(score_mod_scale, dtype, device=device)
 
+    @supported_platform
+    @skip_on_cpu
+    @dtypes(torch.float16)
+    @dtypesIfCUDA(torch.float16)
+    def test_dynamic_captured_buffer(self, device, dtype):
+        def run_with_head_count(compiled_fa, head_count):
+            head_scale = torch.randn(
+                head_count, device=device, dtype=dtype, requires_grad=True
+            )
+
+            def score_mod(score, batch, head, token_q, token_kv):
+                return score * head_scale[head]
+
+            q = torch.randn(
+                B, head_count, S, D, device=device, dtype=dtype, requires_grad=True
+            )
+            k = torch.randn_like(q, requires_grad=True)
+            v = torch.randn_like(q, requires_grad=True)
+
+            block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device)
+
+            out = compiled_fa(q, k, v, score_mod=score_mod, block_mask=block_mask)
+            loss = out.sum()
+            loss.backward()
+            return out
+
+        compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+
+        head_counts = [4, 8, 4, 16, 4]
+        for head_count in head_counts:
+            run_with_head_count(compiled_fa, head_count)
+
     @supported_platform
     @dtypes(*device_configs["cpu"].dtypes_fast)
     @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
@@ -193,7 +193,12 @@ def flex_attention(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
+    if _use_flex_flash_attention(
+        subgraph,
+        mask_graph,
+        kernel_options,
+        num_score_mod_placeholders=len(placeholder_inps),
+    ):
         return create_flex_flash_attention_kernel(
             query,
             key,
@@ -41,21 +41,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
 )
 
 
-def input_buffers_require_grads(graph_module):
-    """Check if any of the input buffers (beyond the first 5) require gradients."""
+def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
+    """Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
     inputs = []
     for node in graph_module.graph.nodes:
         if node.op == "placeholder":
             inputs.append(node)
-    if len(inputs) <= 5:
+    if len(inputs) <= num_score_mod_placeholders:
         return False
-    for n in inputs[5:]:
-        if n.meta["tensor_meta"].requires_grad:
-            return True
-    return False
+
+    def requires_grad(n):
+        tensor_meta = n.meta.get("tensor_meta")
+        return tensor_meta.requires_grad if tensor_meta is not None else False
+
+    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
 
 
-def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
+def is_trivial_graph(
+    graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
+):
     """Check if the flex graphs are compatible with Flash Attention."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
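A standalone sketch of the parameterized check above, for illustration only: the helper body mirrors the new `input_buffers_require_grads`, while the toy score_mod graph, the scalar example inputs, and the use of `symbolic_trace` plus `ShapeProp` to populate `tensor_meta` are assumptions made for this demo (Inductor fills in that metadata through its own tracing):

import torch
from torch import fx
from torch.fx.passes.shape_prop import ShapeProp

def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
    # Mirrors the helper above: only placeholders *past* the canonical
    # score_mod arguments are captured buffers, so only those are inspected.
    inputs = [n for n in graph_module.graph.nodes if n.op == "placeholder"]
    if len(inputs) <= num_score_mod_placeholders:
        return False

    def requires_grad(n):
        tensor_meta = n.meta.get("tensor_meta")  # tolerate placeholders without tensor metadata
        return tensor_meta.requires_grad if tensor_meta is not None else False

    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])

# Toy score_mod with the captured buffer written as an explicit sixth argument,
# standing in for a lifted closure.
def score_mod(score, b, h, q_idx, kv_idx, head_scale):
    return score * head_scale[h]

gm = fx.symbolic_trace(score_mod)
example_inputs = (
    torch.randn(()),                        # score
    torch.zeros((), dtype=torch.int32),     # b
    torch.zeros((), dtype=torch.int32),     # h
    torch.zeros((), dtype=torch.int32),     # q_idx
    torch.zeros((), dtype=torch.int32),     # kv_idx
    torch.randn(8, requires_grad=True),     # captured buffer
)
ShapeProp(gm).propagate(*example_inputs)    # attaches node.meta["tensor_meta"]

print(input_buffers_require_grads(gm, num_score_mod_placeholders=5))  # True

Note the `.get("tensor_meta")` guard: a placeholder that carries no tensor metadata now simply counts as not requiring grad instead of raising a KeyError.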
@@ -65,7 +69,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
     output_val = output[0].args[0]
 
     if is_score_graph:
-        if input_buffers_require_grads(graph_module):
+        if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
             return False
         return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
@@ -73,7 +77,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
 
 
 def _can_use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph
+    subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
 ) -> tuple[bool, str]:
     """Check if flex flash attention can be used for the given inputs.
 
@@ -83,14 +87,22 @@ def _can_use_flex_flash_attention(
     if not ensure_flash_available():
        return False, "CUTE flash attention library is not available"
 
-    if input_buffers_require_grads(subgraph.graph_module):
+    if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
         return (
             False,
             "Input buffers require gradients (not supported by flash attention)",
         )
 
-    score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
-    mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
+    score_trivial = is_trivial_graph(
+        subgraph.graph_module,
+        is_score_graph=True,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
+    mask_trivial = is_trivial_graph(
+        mask_graph.graph_module,
+        is_score_graph=False,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
 
     if not score_trivial and not mask_trivial:
         return (
@@ -112,12 +124,17 @@ def _can_use_flex_flash_attention(
 
 
 def _use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
+    subgraph: Subgraph,
+    mask_graph: Subgraph,
+    kernel_options: dict[str, Any],
+    num_score_mod_placeholders: int,
 ) -> bool:
     """Determine if we should use flex flash attention for the given inputs."""
     force_flash = kernel_options.get("force_flash", False)
 
-    can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
+    can_use, reason = _can_use_flex_flash_attention(
+        subgraph, mask_graph, num_score_mod_placeholders
+    )
 
     if force_flash and not can_use:
         raise RuntimeError(