Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix SAC + Flex issue (#164421)
# Summary

This happens when flex_attention is not tagged with the `CheckpointPolicy.MUST_SAVE` policy, which causes the lse to be left unrealized. In general this is probably not the best policy, but we shouldn't error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164421
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 0e5773b7fa
Commit: cfd46d13e6
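The commit message above notes that the failure shows up when flex_attention runs under selective activation checkpointing without a `CheckpointPolicy.MUST_SAVE` tag, leaving the lse unrealized. Below is a minimal sketch of how a caller could opt into saving flex_attention's outputs instead of recomputing them. It assumes the op is visible to the SAC policy as `torch.ops.higher_order.flex_attention` (the HigherOrderOperator name may differ across PyTorch versions), and the helper names (`policy_fn`, `context_fn`, `checkpointed_flex`) are illustrative, not part of this PR.

```python
from functools import partial

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


def policy_fn(ctx, op, *args, **kwargs):
    # Save flex_attention's outputs (attention output and logsumexp) under SAC
    # so the lse is materialized for the backward pass; recompute everything else.
    # Assumption: the HOP is exposed as torch.ops.higher_order.flex_attention.
    if op == torch.ops.higher_order.flex_attention:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


# Wrap flex_attention in checkpoint() with the selective policy above.
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
checkpointed_flex = partial(
    checkpoint, flex_attention, use_reentrant=False, context_fn=context_fn
)
```

With the fix in this PR, omitting such a tag should no longer error; the sketch just shows the policy the message describes as the usual way to keep the lse saved.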
@@ -4571,6 +4571,111 @@ class GraphModule(torch.nn.Module):
        torch.testing.assert_close(grad_module, grad_compiled, rtol=1e-2, atol=1e-2)

    @supported_platform
    @skip_on_cpu
    def test_selective_ac_with_max_autotune_short_query(self, device):
        from functools import partial

        from torch.utils.checkpoint import (
            checkpoint,
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        compute_intensive_ops = [
            torch.ops.aten.mm,
            torch.ops.aten.bmm,
        ]

        def policy_fn(ctx, op, *args, **kwargs):
            if op in compute_intensive_ops:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        class DummyAttentionModule(nn.Module):
            def __init__(self, dim=64, num_heads=4):
                super().__init__()
                self.dim = dim
                self.num_heads = num_heads
                self.head_dim = dim // num_heads

                self.q_proj = nn.Linear(dim, dim)
                self.k_proj = nn.Linear(dim, dim)
                self.v_proj = nn.Linear(dim, dim)
                self.out_proj = nn.Linear(dim, dim)

                self._activation_checkpoint_context_fn = partial(
                    create_selective_checkpoint_contexts, policy_fn
                )

                self._flex_attention = torch.compile(
                    partial(
                        checkpoint,
                        flex_attention,
                        use_reentrant=False,
                        context_fn=self._activation_checkpoint_context_fn,
                    ),
                    mode="max-autotune-no-cudagraphs",
                )

            def forward(self, x, block_mask):
                batch_size, seq_len, _ = x.shape

                q = self.q_proj(x)
                k = self.k_proj(x)
                v = self.v_proj(x)

                q = q.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)
                k = k.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)
                v = v.view(
                    batch_size, seq_len, self.num_heads, self.head_dim
                ).transpose(1, 2)

                attn_out = self._flex_attention(q, k, v, block_mask=block_mask)

                attn_out = (
                    attn_out.transpose(1, 2)
                    .contiguous()
                    .view(batch_size, seq_len, self.dim)
                )

                out = self.out_proj(attn_out)

                return out

        batch_size = 2
        seq_len = 64
        dim = 64
        num_heads = 4

        model = DummyAttentionModule(dim=dim, num_heads=num_heads).to(device)

        x = torch.randn(batch_size, seq_len, dim, device=device, requires_grad=True)

        block_mask = create_block_mask(
            causal_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=seq_len,
            KV_LEN=seq_len,
            device=device,
        )

        out = model(x, block_mask)

        loss = out.sum()
        loss.backward()

        self.assertIsNotNone(x.grad)

    @supported_platform
    @skip_on_cpu
    def test_validate_small_embedding_size_error_message(self, device):
@@ -569,6 +569,7 @@ def flex_attention_backward(*args, **kwargs):
        query,
        key,
        value,
        logsumexp,
        grad_out,
        kv_num_blocks,
        kv_indices,
@@ -583,6 +584,7 @@ def flex_attention_backward(*args, **kwargs):
        query,
        key,
        value,
        logsumexp,
        grad_out,
        kv_num_blocks,
        kv_indices,