[NJT][flop counter] attention: if offsets are fake, use max seqlen (#132356)

The flop counter is used by the partitioner, in which case the tensors passed in can be fake.

The flop computations for nested attention use the offsets to determine the actual amount of compute that will be done. But when the offsets are fake, `(offsets[1:] - offsets[:-1]).tolist()` produces unbacked symints. If the offsets are fake or functional tensors, fall back to the max sequence length instead.
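
The idea, as a standalone sketch (illustrative only; the helper name below is made up, the real change adds `_offsets_to_lengths` in `torch/utils/flop_counter.py` as shown in the diff):

```python
# Illustrative sketch of the fallback, not the actual PyTorch helper.
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor


def lengths_from_offsets(offsets, max_len):
    if isinstance(offsets, (FakeTensor, FunctionalTensor)):
        # Reading values out of a fake tensor is data-dependent (unbacked
        # symints under tracing), so assume the worst case: every batch
        # entry runs at the max sequence length.
        return [max_len] * (offsets.size(0) - 1)
    return (offsets[1:] - offsets[:-1]).tolist()


offsets = torch.tensor([0, 30, 60, 90, 123])
print(lengths_from_offsets(offsets, 40))  # [30, 30, 30, 33]

with FakeTensorMode() as fake_mode:
    fake_offsets = fake_mode.from_tensor(offsets)
    print(lengths_from_offsets(fake_offsets, 40))  # [40, 40, 40, 40]
```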

Repro: https://gist.github.com/davidberard98/903fb3e586edb6d1d466786e1a610eba

Differential Revision: [D60597463](https://our.internmc.facebook.com/intern/diff/D60597463)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132356
Approved by: https://github.com/soulitzer
Author: David Berard
Date: 2024-08-02 09:53:42 -07:00
Committed by: PyTorch MergeBot
Parent: 37c3d503b7
Commit: 1962f9475f

2 changed files with 64 additions and 4 deletions

test/test_flop_counter.py

@@ -6,6 +6,7 @@ import unittest
 import torch
 import torch.nn.functional as F
 import torch.utils.flop_counter
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -675,6 +676,53 @@ class TestFlopCounter(TestCase):
             ),
         )

+    @skipIfRocm  # Nested tensor
+    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
+    )
+    def test_nested_attention_fake_tensors(self):
+        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
+        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
+        max_seqlen = 40
+        with FakeTensorMode() as fake_mode:
+            fake_x = fake_mode.from_tensor(x)
+            fake_offsets = fake_mode.from_tensor(offsets)
+
+            with FlopCounterMode() as fake_flop_counter_mode:
+                torch.ops.aten._flash_attention_forward(
+                    fake_x,
+                    fake_x,
+                    fake_x,
+                    fake_offsets,
+                    fake_offsets,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    False,
+                    False,
+                )
+
+        dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
+
+        with FlopCounterMode() as real_flop_counter_mode:
+            torch.ops.aten._flash_attention_forward(
+                dense_x,
+                dense_x,
+                dense_x,
+                None,
+                None,
+                max_seqlen,
+                max_seqlen,
+                0.0,
+                False,
+                False,
+            )
+
+        self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode)))
+
     def test_addmm_out(self):
         def f(x):
             y = torch.zeros(10, 10)
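
Aside (not part of the diff): the dense comparison in the test can also be reproduced through the public SDPA API, since the flop counter also covers the SDPA front-end ops. A minimal sketch, assuming a CUDA device with bf16 support; the shapes match the test's `dense_x` after the transpose:

```python
# Illustration only: count SDPA flops for a dense (4, 4, 40, 16) batch,
# the same shape the test builds via dense_x.transpose(1, 2).
import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

q = k = v = torch.randn(4, 4, 40, 16, device="cuda", dtype=torch.bfloat16)
with FlopCounterMode(display=False) as counter:
    F.scaled_dot_product_attention(q, k, v)
print(counter.get_total_flops())
```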

torch/utils/flop_counter.py

@@ -265,6 +265,18 @@ def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwargs)
     return sdpa_flop_count(query_shape, key_shape, value_shape)


+def _offsets_to_lengths(offsets, max_len):
+    """
+    If the offsets tensor is fake, then we don't know the actual lengths.
+    In that case, we can just assume the worst case; each batch has max length.
+    """
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import FunctionalTensor
+    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
+        return offsets.diff().tolist()
+    return [max_len] * (offsets.size(0) - 1)
+
+
 def _unpack_flash_attention_nested_shapes(
     *,
     query,
@@ -298,8 +310,8 @@ def _unpack_flash_attention_nested_shapes(
         assert cum_seq_q is not None
         assert cum_seq_k is not None
         assert cum_seq_q.shape == cum_seq_k.shape
-        seq_q_lengths = (cum_seq_q[1:] - cum_seq_q[:-1]).tolist()
-        seq_k_lengths = (cum_seq_k[1:] - cum_seq_k[:-1]).tolist()
+        seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
+        seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
         for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
             new_query_shape = (1, h_q, seq_q_len, d_q)
             new_key_shape = (1, h_k, seq_k_len, d_k)
@@ -346,8 +358,8 @@ def _unpack_efficient_attention_nested_shapes(
         assert cu_seqlens_q is not None
         assert cu_seqlens_k is not None
         assert cu_seqlens_q.shape == cu_seqlens_k.shape
-        seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
-        seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).tolist()
+        seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
+        seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
         for len_q, len_k in zip(seqlens_q, seqlens_k):
             new_query_shape = (1, h_q, len_q, d_q)
             new_key_shape = (1, h_k, len_k, d_k)
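
Why the test's two counts agree, as a back-of-the-envelope check: the worst-case fallback treats the jagged batch as 4 sequences at max_seqlen=40, which is exactly what the dense run computes. The sketch below assumes the counter's SDPA forward cost model of two batched matmuls (QK^T and attn @ V), each costing 2*M*N*K flops; it is intended to mirror `sdpa_flop_count`, restated here from memory:

```python
# Sketch of the arithmetic behind the test's assertEqual, assuming the
# two-matmul SDPA forward cost model (2*M*N*K flops per batched matmul).
def sdpa_fwd_flops(b, h, s_q, s_k, d):
    qk = 2 * b * h * s_q * s_k * d  # Q @ K^T
    av = 2 * b * h * s_q * s_k * d  # attn @ V (d_v == d here)
    return qk + av


# Fake-offsets path: 4 jagged sequences, each assumed to run at max_seqlen=40.
fake_total = 4 * sdpa_fwd_flops(1, 4, 40, 40, 16)

# Dense path: a single batch of 4 sequences, all of length 40.
dense_total = sdpa_fwd_flops(4, 4, 40, 40, 16)

assert fake_total == dense_total == 1_638_400
```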