mirror of https://github.com/pytorch/pytorch.git
Update
[ghstack-poisoned]
@@ -1973,6 +1973,38 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
 
         self.run_test(score_mod_scale, dtype, device=device)
 
+    @supported_platform
+    @skip_on_cpu
+    @dtypes(torch.float16)
+    @dtypesIfCUDA(torch.float16)
+    def test_dynamic_captured_buffer(self, device, dtype):
+        def run_with_head_count(compiled_fa, head_count):
+            head_scale = torch.randn(
+                head_count, device=device, dtype=dtype, requires_grad=True
+            )
+
+            def score_mod(score, batch, head, token_q, token_kv):
+                return score * head_scale[head]
+
+            q = torch.randn(
+                B, head_count, S, D, device=device, dtype=dtype, requires_grad=True
+            )
+            k = torch.randn_like(q, requires_grad=True)
+            v = torch.randn_like(q, requires_grad=True)
+
+            block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device)
+
+            out = compiled_fa(q, k, v, score_mod=score_mod, block_mask=block_mask)
+            loss = out.sum()
+            loss.backward()
+            return out
+
+        compiled_fa = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+
+        head_counts = [4, 8, 4, 16, 4]
+        for head_count in head_counts:
+            run_with_head_count(compiled_fa, head_count)
+
     @supported_platform
     @dtypes(*device_configs["cpu"].dtypes_fast)
     @dtypesIfCUDA(*device_configs["cuda"].dtypes_fast)
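
For reference, the captured-buffer score_mod pattern this test compiles also works in eager mode. The sketch below is a minimal standalone version of it; the concrete shapes, the noop_mask helper, and the device/dtype selection are assumptions standing in for the test file's own constants and helpers.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
B, H, S, D = 2, 4, 128, 64  # assumed shapes; the test defines its own B, S, D

# A learnable per-head scale captured by the score_mod closure.
head_scale = torch.randn(H, device=device, dtype=dtype, requires_grad=True)

def score_mod(score, batch, head, token_q, token_kv):
    return score * head_scale[head]

def noop_mask(b, h, q_idx, kv_idx):
    # Keeps every (q, kv) pair; stands in for the test file's noop_mask helper.
    return q_idx >= 0

q = torch.randn(B, H, S, D, device=device, dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

block_mask = create_block_mask(noop_mask, B, 1, S, S, device=device)
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)
out.sum().backward()  # gradients flow into q, k, v and the captured head_scale
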
@@ -193,7 +193,12 @@ def flex_attention(
         score_mod_other_buffers,
         mask_mod_other_buffers,
     )
-    if _use_flex_flash_attention(subgraph, mask_graph, kernel_options):
+    if _use_flex_flash_attention(
+        subgraph,
+        mask_graph,
+        kernel_options,
+        num_score_mod_placeholders=len(placeholder_inps),
+    ):
         return create_flex_flash_attention_kernel(
             query,
             key,
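
Here len(placeholder_inps) is the number of canonical score_mod placeholders (score, b, h, q_idx, kv_idx) that the lowering itself created, so the downstream checks no longer hard-code 5; every placeholder past that count in the traced subgraph is a captured buffer. The sketch below illustrates that placeholder layout on a toy graph. score_mod_with_buffer and the example inputs are illustrative only, not the lowering's actual tracing path.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def score_mod_with_buffer(score, b, h, q_idx, kv_idx, head_scale):
    # head_scale plays the role of a captured buffer lifted to an extra input.
    return score * head_scale[h]

example_inputs = (
    torch.zeros(()),                     # score
    torch.zeros((), dtype=torch.int64),  # b
    torch.zeros((), dtype=torch.int64),  # h
    torch.zeros((), dtype=torch.int64),  # q_idx
    torch.zeros((), dtype=torch.int64),  # kv_idx
    torch.randn(8),                      # captured buffer
)
gm = make_fx(score_mod_with_buffer)(*example_inputs)

placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
num_score_mod_placeholders = 5           # analogous to len(placeholder_inps)
buffers = placeholders[num_score_mod_placeholders:]
print(len(placeholders), len(buffers))   # 6 1
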
@@ -41,21 +41,25 @@ flash_attention_cutedsl_template = CuteDSLTemplate(
 )
 
 
-def input_buffers_require_grads(graph_module):
-    """Check if any of the input buffers (beyond the first 5) require gradients."""
+def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int):
+    """Check if any of the input buffers (beyond the score mod placeholders) require gradients."""
     inputs = []
     for node in graph_module.graph.nodes:
         if node.op == "placeholder":
             inputs.append(node)
-    if len(inputs) <= 5:
+    if len(inputs) <= num_score_mod_placeholders:
         return False
-    for n in inputs[5:]:
-        if n.meta["tensor_meta"].requires_grad:
-            return True
-    return False
+
+    def requires_grad(n):
+        tensor_meta = n.meta.get("tensor_meta")
+        return tensor_meta.requires_grad if tensor_meta is not None else False
+
+    return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:])
 
 
-def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
+def is_trivial_graph(
+    graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int
+):
     """Check if the flex graphs are compatible with Flash Attention."""
     graph = graph_module.graph
     nodes = list(graph.nodes)
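
The rewritten helper drops the hard-coded 5, slices off the first num_score_mod_placeholders graph inputs, and reads requires_grad defensively from tensor_meta. Below is a hedged sketch of the same check on a toy graph, using ShapeProp to populate tensor_meta; the traced function and inputs are illustrative only.

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.shape_prop import ShapeProp

def score_mod_with_buffer(score, b, h, q_idx, kv_idx, head_scale):
    return score * head_scale[h]

args = (
    torch.zeros(()),
    torch.zeros((), dtype=torch.int64),
    torch.zeros((), dtype=torch.int64),
    torch.zeros((), dtype=torch.int64),
    torch.zeros((), dtype=torch.int64),
    torch.randn(8, requires_grad=True),  # captured buffer that needs grad
)
gm = make_fx(score_mod_with_buffer)(*args)
ShapeProp(gm).propagate(*args)  # fills node.meta["tensor_meta"], incl. requires_grad

num_score_mod_placeholders = 5
inputs = [n for n in gm.graph.nodes if n.op == "placeholder"]

def requires_grad(n):
    tensor_meta = n.meta.get("tensor_meta")
    return tensor_meta.requires_grad if tensor_meta is not None else False

print(any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]))  # True
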
@@ -65,7 +69,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
     output_val = output[0].args[0]
 
     if is_score_graph:
-        if input_buffers_require_grads(graph_module):
+        if input_buffers_require_grads(graph_module, num_score_mod_placeholders):
             return False
         return True  # party on garth
     # mask mod graph is empty if we have 4 inputs and full_default output
@@ -73,7 +77,7 @@ def is_trivial_graph(graph_module: GraphModule, is_score_graph: bool):
 
 
 def _can_use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph
+    subgraph: Subgraph, mask_graph: Subgraph, num_score_mod_placeholders: int
 ) -> tuple[bool, str]:
     """Check if flex flash attention can be used for the given inputs.
 
@@ -83,14 +87,22 @@ def _can_use_flex_flash_attention(
     if not ensure_flash_available():
         return False, "CUTE flash attention library is not available"
 
-    if input_buffers_require_grads(subgraph.graph_module):
+    if input_buffers_require_grads(subgraph.graph_module, num_score_mod_placeholders):
         return (
             False,
             "Input buffers require gradients (not supported by flash attention)",
         )
 
-    score_trivial = is_trivial_graph(subgraph.graph_module, is_score_graph=True)
-    mask_trivial = is_trivial_graph(mask_graph.graph_module, is_score_graph=False)
+    score_trivial = is_trivial_graph(
+        subgraph.graph_module,
+        is_score_graph=True,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
+    mask_trivial = is_trivial_graph(
+        mask_graph.graph_module,
+        is_score_graph=False,
+        num_score_mod_placeholders=num_score_mod_placeholders,
+    )
 
     if not score_trivial and not mask_trivial:
         return (
@@ -112,12 +124,17 @@ def _can_use_flex_flash_attention(
 
 
 def _use_flex_flash_attention(
-    subgraph: Subgraph, mask_graph: Subgraph, kernel_options: dict[str, Any]
+    subgraph: Subgraph,
+    mask_graph: Subgraph,
+    kernel_options: dict[str, Any],
+    num_score_mod_placeholders: int,
 ) -> bool:
     """Determine if we should use flex flash attention for the given inputs."""
     force_flash = kernel_options.get("force_flash", False)
 
-    can_use, reason = _can_use_flex_flash_attention(subgraph, mask_graph)
+    can_use, reason = _can_use_flex_flash_attention(
+        subgraph, mask_graph, num_score_mod_placeholders
+    )
 
     if force_flash and not can_use:
         raise RuntimeError(
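
With the extra argument threaded through, force_flash keeps its role as an opt-in kernel option: when it is set and the score/mask graphs are not flash-compatible, the lowering raises instead of silently falling back to the Triton path. Below is a hedged usage sketch, assuming the flag is forwarded unchanged through the public kernel_options dict and that the CUTE flash library plus a CUDA device are available.

import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    return score  # trivial score_mod: eligible for the flash path

q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

compiled = torch.compile(flex_attention, fullgraph=True)
# If _use_flex_flash_attention decides flash cannot be used, force_flash turns
# the quiet fallback into a RuntimeError at compile time.
out = compiled(q, k, v, score_mod=score_mod, kernel_options={"force_flash": True})
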