Fix SAC + Flex issue (#164421)

# Summary

This happens when flex_attention is not tagged with the `CheckpointPolicy.MUST_SAVE` policy, which causes the lse (logsumexp) to be unrealized. I think in general this is probably not the best policy, but we shouldn't error.
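
For reference, a minimal sketch (not part of this PR) of a policy that keeps flex_attention saved under SAC; the name `save_flex_policy` and the assumption that the op reaches the policy as `torch.ops.higher_order.flex_attention` are illustrative, not taken from this change:

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def save_flex_policy(ctx, op, *args, **kwargs):
    # Assumption: flex_attention surfaces to the SAC policy as this
    # higher-order op; saving it keeps the lse realized for backward.
    if op is torch.ops.higher_order.flex_attention:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


# Passed as checkpoint(..., context_fn=context_fn).
context_fn = partial(create_selective_checkpoint_contexts, save_flex_policy)
```

Either way, the fix here simply threads `logsumexp` through `flex_attention_backward` alongside the other tensors (second diff below), so an unrealized lse no longer errors.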

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164421
Approved by: https://github.com/Skylion007
Authored by drisspg on 2025-10-02 00:17:18 +00:00; committed by PyTorch MergeBot
parent 0e5773b7fa
commit cfd46d13e6
2 changed files with 107 additions and 0 deletions

@@ -4571,6 +4571,111 @@ class GraphModule(torch.nn.Module):
        torch.testing.assert_close(grad_module, grad_compiled, rtol=1e-2, atol=1e-2)

    @supported_platform
    @skip_on_cpu
    def test_selective_ac_with_max_autotune_short_query(self, device):
        from functools import partial
        from torch.utils.checkpoint import (
            checkpoint,
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        compute_intensive_ops = [
            torch.ops.aten.mm,
            torch.ops.aten.bmm,
        ]

        # Save only the compute-intensive matmuls; everything else,
        # including flex_attention, is left to PREFER_RECOMPUTE.
        def policy_fn(ctx, op, *args, **kwargs):
            if op in compute_intensive_ops:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        class DummyAttentionModule(nn.Module):
            def __init__(self, dim=64, num_heads=4):
                super().__init__()
                self.dim = dim
                self.num_heads = num_heads
                self.head_dim = dim // num_heads
                self.q_proj = nn.Linear(dim, dim)
                self.k_proj = nn.Linear(dim, dim)
                self.v_proj = nn.Linear(dim, dim)
                self.out_proj = nn.Linear(dim, dim)
                self._activation_checkpoint_context_fn = partial(
                    create_selective_checkpoint_contexts, policy_fn
                )
                # Compile checkpointed flex_attention with the SAC context
                # under max-autotune.
                self._flex_attention = torch.compile(
                    partial(
                        checkpoint,
                        flex_attention,
                        use_reentrant=False,
                        context_fn=self._activation_checkpoint_context_fn,
                    ),
                    mode="max-autotune-no-cudagraphs",
                )

            def forward(self, x, block_mask):
                batch_size, seq_len, _ = x.shape
                q = self.q_proj(x)
                k = self.k_proj(x)
                v = self.v_proj(x)
                q = q.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)
                k = k.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)
                v = v.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)
                attn_out = self._flex_attention(q, k, v, block_mask=block_mask)
                attn_out = (
                    attn_out.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_len, self.dim)
                )
                out = self.out_proj(attn_out)
                return out

        batch_size = 2
        seq_len = 64
        dim = 64
        num_heads = 4
        model = DummyAttentionModule(dim=dim, num_heads=num_heads).to(device)
        x = torch.randn(batch_size, seq_len, dim, device=device, requires_grad=True)
        block_mask = create_block_mask(
            causal_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=device,
        )

        out = model(x, block_mask)
        loss = out.sum()
        loss.backward()
        self.assertIsNotNone(x.grad)

    @supported_platform
    @skip_on_cpu
    def test_validate_small_embedding_size_error_message(self, device):

@@ -569,6 +569,7 @@ def flex_attention_backward(*args, **kwargs):
query,
key,
value,
logsumexp,
grad_out,
kv_num_blocks,
kv_indices,
@@ -583,6 +584,7 @@ def flex_attention_backward(*args, **kwargs):
query,
key,
value,
logsumexp,
grad_out,
kv_num_blocks,
kv_indices,