Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
[NJT][flop counter] attention: if offsets are fake, use max seqlen (#132356)

The flop counter is used by the partitioner, in which case the tensors passed in can be fake. The flop computations for nested attention use the offsets to determine the actual amount of compute that will be done. But when the offsets are fake, we end up with unbacked symints (from `(offsets[1:] - offsets[:-1]).tolist()`). If the offsets are fake or functional tensors, we use the max sequence length instead.

Repro: https://gist.github.com/davidberard98/903fb3e586edb6d1d466786e1a610eba
Differential Revision: [D60597463](https://our.internmc.facebook.com/intern/diff/D60597463)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132356
Approved by: https://github.com/soulitzer
Committed by PyTorch MergeBot
Parent: 37c3d503b7
Commit: 1962f9475f
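To illustrate the failure mode described in the commit message: under FakeTensorMode the offsets have no concrete values, so differencing them and calling .tolist() is a data-dependent operation. A minimal sketch, not part of this commit (exact behavior depends on whether a ShapeEnv is attached, as noted in the comments):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    offsets = torch.tensor([0, 30, 60, 90, 123])
    with FakeTensorMode() as fake_mode:
        fake_offsets = fake_mode.from_tensor(offsets)
        try:
            # Data-dependent: the fake tensor has no concrete values to list out.
            lengths = (fake_offsets[1:] - fake_offsets[:-1]).tolist()
        except Exception as e:
            # Without a ShapeEnv this raises a data-dependent-output error; with
            # one, it yields unbacked SymInts rather than usable Python ints.
            print(type(e).__name__)

This is why the change below falls back to the max sequence length whenever the offsets are fake or functional tensors.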
@@ -6,6 +6,7 @@ import unittest
 import torch
 import torch.nn.functional as F
 import torch.utils.flop_counter
+from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
@@ -675,6 +676,53 @@ class TestFlopCounter(TestCase):
             ),
         )

+    @skipIfRocm  # Nested tensor
+    @unittest.skipIf(not HAS_CUDA, "CUDA not available")
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Does not support all SDPA backends (pre-SM80 hardware on CUDA)",
+    )
+    def test_nested_attention_fake_tensors(self):
+        x = torch.randn(123, 4, 16, device="cuda", dtype=torch.bfloat16)
+        offsets = torch.tensor([0, 30, 60, 90, 123], device="cuda")
+        max_seqlen = 40
+        with FakeTensorMode() as fake_mode:
+            fake_x = fake_mode.from_tensor(x)
+            fake_offsets = fake_mode.from_tensor(offsets)
+
+            with FlopCounterMode() as fake_flop_counter_mode:
+                torch.ops.aten._flash_attention_forward(
+                    fake_x,
+                    fake_x,
+                    fake_x,
+                    fake_offsets,
+                    fake_offsets,
+                    max_seqlen,
+                    max_seqlen,
+                    0.0,
+                    False,
+                    False,
+                )
+
+        dense_x = torch.randn(4, 40, 4, 16, dtype=torch.bfloat16, device="cuda").transpose(1, 2)
+
+        with FlopCounterMode() as real_flop_counter_mode:
+            torch.ops.aten._flash_attention_forward(
+                dense_x,
+                dense_x,
+                dense_x,
+                None,
+                None,
+                max_seqlen,
+                max_seqlen,
+                0.0,
+                False,
+                False,
+            )
+
+        self.assertEqual(int(get_total_flops(fake_flop_counter_mode)), int(get_total_flops(real_flop_counter_mode)))
+
+
     def test_addmm_out(self):
         def f(x):
             y = torch.zeros(10, 10)
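For reference, a back-of-the-envelope check of why the test's two totals match, assuming the counter's usual SDPA model of 2·h·s_q·s_k·d FLOPs for q @ kᵀ plus the same again for attn @ v (illustrative arithmetic, not taken from the diff): with fake offsets every one of the 4 sequences is assumed to have the max length of 40, which is exactly the dense [4, 4, 40, 16] input.

    h, d, max_seqlen, batch = 4, 16, 40, 4

    # Fake/nested path: all sequences assumed to be max_seqlen long.
    fake_total = sum(2 * h * max_seqlen * max_seqlen * d * 2 for _ in range(batch))

    # Dense path: one [batch, h, max_seqlen, d] attention call.
    dense_total = 2 * batch * h * max_seqlen * max_seqlen * d * 2

    assert fake_total == dense_total == 1_638_400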
@@ -265,6 +265,18 @@ def sdpa_flop(query_shape, key_shape, value_shape, *args, out_shape=None, **kwar
     return sdpa_flop_count(query_shape, key_shape, value_shape)


+def _offsets_to_lengths(offsets, max_len):
+    """
+    If the offsets tensor is fake, then we don't know the actual lengths.
+    In that case, we can just assume the worst case; each batch has max length.
+    """
+    from torch._subclasses.fake_tensor import FakeTensor
+    from torch._subclasses.functional_tensor import FunctionalTensor
+    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
+        return offsets.diff().tolist()
+    return [max_len] * (offsets.size(0) - 1)
+
+
 def _unpack_flash_attention_nested_shapes(
     *,
     query,
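Illustrative usage of the new helper (assuming this change is applied; _offsets_to_lengths is private and is only called from the unpack helpers below): real offsets give the exact per-sequence lengths, fake offsets give the worst-case max length for every sequence.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils.flop_counter import _offsets_to_lengths  # private helper added above

    offsets = torch.tensor([0, 30, 60, 90, 123])
    print(_offsets_to_lengths(offsets, 40))       # [30, 30, 30, 33] -- exact lengths

    with FakeTensorMode() as fake_mode:
        fake_offsets = fake_mode.from_tensor(offsets)
    print(_offsets_to_lengths(fake_offsets, 40))  # [40, 40, 40, 40] -- worst-case fallback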
@@ -298,8 +310,8 @@ def _unpack_flash_attention_nested_shapes(
     assert cum_seq_q is not None
     assert cum_seq_k is not None
     assert cum_seq_q.shape == cum_seq_k.shape
-    seq_q_lengths = (cum_seq_q[1:] - cum_seq_q[:-1]).tolist()
-    seq_k_lengths = (cum_seq_k[1:] - cum_seq_k[:-1]).tolist()
+    seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
+    seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
     for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
         new_query_shape = (1, h_q, seq_q_len, d_q)
         new_key_shape = (1, h_k, seq_k_len, d_k)
@@ -346,8 +358,8 @@ def _unpack_efficient_attention_nested_shapes(
     assert cu_seqlens_q is not None
     assert cu_seqlens_k is not None
     assert cu_seqlens_q.shape == cu_seqlens_k.shape
-    seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()
-    seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).tolist()
+    seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
+    seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
     for len_q, len_k in zip(seqlens_q, seqlens_k):
         new_query_shape = (1, h_q, len_q, d_q)
         new_key_shape = (1, h_k, len_k, d_k)
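Both the flash- and efficient-attention unpack helpers then feed these per-sequence (1, h, seq_len, d) shapes into the SDPA FLOP formula one sequence at a time. A rough sketch of the resulting totals, using the same assumed per-sequence model as the check after the test above (illustrative only, not part of the diff): real offsets give the exact jagged count, while fake offsets give a worst-case upper bound.

    def per_sequence_flops(lengths, h=4, d=16):
        # 2*h*s*s*d for q @ k^T plus 2*h*s*s*d for attn @ v, per sequence
        return sum(4 * h * s * s * d for s in lengths)

    real_lengths = [30, 30, 30, 33]   # offsets.diff().tolist() on real offsets
    fake_lengths = [40, 40, 40, 40]   # max-seqlen fallback when offsets are fake

    print(per_sequence_flops(real_lengths))  # 969984  -- exact jagged count
    print(per_sequence_flops(fake_lengths))  # 1638400 -- upper bound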