[NJT] fix flop counter for SDPA & test (#147032)

Fixes 3 issues:
1. The test wasn't actually testing SDPA: both nested tensors were built from the CUDA inputs (the meta inputs went unused), and the inputs to SDPA were not transposed into the (batch, num_heads, seq_len, head_dim) layout SDPA expects. A sketch of the intended usage follows this list.
2. FlopCounterMode has been renamed to _FlopCounterMode (and a public wrapper named FlopCounterMode has been added), so the dispatch-mode check must match the private class.
3. _offsets_to_lengths also needs to ignore the actual offset values when offsets is a meta tensor, since a meta tensor carries no data to read.
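
For context, a minimal sketch of how SDPA flop counting is typically exercised. Dense CPU tensors are used here for simplicity; the shapes and the `display=False` usage are illustrative assumptions, not taken from the test:

```python
import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

# SDPA expects (batch, num_heads, seq_len, head_dim); a tensor built as
# (batch, seq_len, num_heads, head_dim) has to be transposed first -- the
# same transpose the fixed test now applies to its nested tensors.
q = k = v = torch.randn(8, 16, 4, 16).transpose(1, 2)  # -> (8, 4, 16, 16)

with FlopCounterMode(display=False) as counter:
    F.scaled_dot_product_attention(q, k, v)

print(counter.get_total_flops())  # nonzero, since SDPA actually ran
```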

Differential Revision: [D69558785](https://our.internmc.facebook.com/intern/diff/D69558785)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147032
Approved by: https://github.com/jbschlosser
Author: David Berard
Date: 2025-02-12 15:18:11 -08:00
Committed by: PyTorch MergeBot
Parent: b9a22b3f37
Commit: 43496e9b90
3 changed files with 8 additions and 4 deletions

@@ -7021,13 +7021,17 @@ torch.cuda.synchronize()
             (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
         )
         offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
-        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
+        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16).transpose(
+            1, 2
+        )
         values_meta = torch.randn(
             (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
         )
         offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
-        nt_meta = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
+        nt_meta = convert_jagged_to_nested_tensor(
+            values_meta, offsets_meta, max_length=16
+        ).transpose(1, 2)
         self.assertEqual(get_flops(nt), get_flops(nt_meta))
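
The `get_flops` helper isn't shown in this hunk; a hypothetical reconstruction of what the test likely does with each nested tensor would look roughly like this:

```python
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

def get_flops(nt):
    # Run SDPA on the (transposed) nested tensor under the flop counter
    # and return the aggregate flop count.
    with FlopCounterMode(display=False) as mode:
        F.scaled_dot_product_attention(nt, nt, nt)
    return mode.get_total_flops()
```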

@@ -623,7 +623,7 @@ def _is_computing_meta_flops(x):
             torch.utils._python_dispatch._get_current_dispatch_mode_stack()
         )
         return any(
-            type(x) == torch.utils.flop_counter.FlopCounterMode
+            type(x) == torch.utils.flop_counter._FlopCounterMode
            for x in torch_dispatch_mode_stack
        )
    return False
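
A quick way to see why the check needs the private name. This sketch relies on private dispatch-mode APIs, so it is subject to change:

```python
import torch
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack
from torch.utils.flop_counter import FlopCounterMode, _FlopCounterMode

with FlopCounterMode(display=False):
    # The public FlopCounterMode is now a wrapper; what actually sits on
    # the dispatch mode stack is the private _FlopCounterMode, which is
    # what _is_computing_meta_flops must match.
    stack = _get_current_dispatch_mode_stack()
    assert any(type(m) is _FlopCounterMode for m in stack)
```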

@@ -290,7 +290,7 @@ def _offsets_to_lengths(offsets, max_len):
     """
     from torch._subclasses.fake_tensor import FakeTensor
     from torch._subclasses.functional_tensor import FunctionalTensor
-    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
+    if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
         return offsets.diff().tolist()
     return [max_len] * (offsets.size(0) - 1)
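
Why the extra guard: a meta tensor has shape and dtype but no storage, so the real lengths can't be read off of it. A small illustration, assuming (as I understand it) that `.tolist()` on a meta tensor raises because there is no data to copy out:

```python
import torch

offsets = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)

# offsets.diff() stays on the meta device, but materializing it with
# .tolist() has no data to copy out, so _offsets_to_lengths must fall
# back to assuming every sequence is max_len long:
max_len = 16
lengths = [max_len] * (offsets.size(0) - 1)
print(lengths)  # [16, 16, 16, 16, 16, 16, 16, 16]
```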