[NJT] fix flop counter for SDPA & test (#147032)
Fixes 3 issues:

1. The test wasn't actually testing SDPA: both tensors were built from the cuda inputs (`nt_meta` was constructed from `values`/`offsets` rather than `values_meta`/`offsets_meta`), and the inputs to SDPA were not transposed into the expected `(batch, heads, seq_len, head_dim)` layout.
2. `FlopCounterMode` has been renamed `_FlopCounterMode` (and a wrapper named `FlopCounterMode` has been added), so the type check against the dispatch mode stack never matched.
3. `_offsets_to_lengths` also needs to ignore the actual offset values if `offsets` is a meta tensor, since meta tensors carry no data to read.

Differential Revision: [D69558785](https://our.internmc.facebook.com/intern/diff/D69558785)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147032
Approved by: https://github.com/jbschlosser
committed by PyTorch MergeBot
parent b9a22b3f37
commit 43496e9b90
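For context, a minimal sketch of the public flop-counter API involved here (my illustration, not from the patch); dense tensors are used for brevity, while the fixed test runs the same pattern on jagged NestedTensors:

```python
# Sketch: counting SDPA FLOPs with the public FlopCounterMode wrapper.
import torch
import torch.nn.functional as F
from torch.utils.flop_counter import FlopCounterMode

# SDPA expects (batch, heads, seq_len, head_dim) -- the layout point
# behind the missing transpose in the old test.
q = k = v = torch.randn(8, 4, 16, 16)

with FlopCounterMode(display=False) as fc:
    F.scaled_dot_product_attention(q, k, v)
print(fc.get_total_flops())
```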
```diff
@@ -7021,13 +7021,17 @@ torch.cuda.synchronize()
             (8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
         )
         offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
-        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
+        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16).transpose(
+            1, 2
+        )
 
         values_meta = torch.randn(
             (8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
         )
         offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
-        nt_meta = convert_jagged_to_nested_tensor(values, offsets, max_length=16)
+        nt_meta = convert_jagged_to_nested_tensor(
+            values_meta, offsets_meta, max_length=16
+        ).transpose(1, 2)
 
         self.assertEqual(get_flops(nt), get_flops(nt_meta))
 
```
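Roughly what the test now constructs, as a hedged sketch: it assumes the test helper `convert_jagged_to_nested_tensor` wraps the public `torch.nested.nested_tensor_from_jagged` (an assumption; the helper's definition is not part of this hunk):

```python
# Hedged sketch of the test's construction via the public jagged API.
import torch

values = torch.randn(8 * 16, 4, 16)        # packed (sum(j1), heads, head_dim)
offsets = torch.arange(0, 8 * 16 + 1, 16)  # 8 rows of length 16

# nested_tensor_from_jagged yields (B, j1, H, D); SDPA wants the heads dim
# before the ragged seq dim, hence the .transpose(1, 2) the patch adds.
nt = torch.nested.nested_tensor_from_jagged(values, offsets).transpose(1, 2)
print(nt.shape)  # torch.Size([8, 4, j1, 16])
```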
```diff
@@ -623,7 +623,7 @@ def _is_computing_meta_flops(x):
             torch.utils._python_dispatch._get_current_dispatch_mode_stack()
         )
         return any(
-            type(x) == torch.utils.flop_counter.FlopCounterMode
+            type(x) == torch.utils.flop_counter._FlopCounterMode
             for x in torch_dispatch_mode_stack
         )
     return False
```
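The rename matters because the public `FlopCounterMode` wrapper is no longer the mode that lands on the dispatch stack; it pushes the private `_FlopCounterMode`, so the old type check never matched. A quick sketch to see this (it pokes at private internals, so treat it as illustrative):

```python
# Illustrative sketch: entering the public FlopCounterMode pushes a
# _FlopCounterMode onto the dispatch mode stack, which is what
# _is_computing_meta_flops must now look for.
import torch
from torch.utils.flop_counter import FlopCounterMode
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

with FlopCounterMode(display=False):
    stack = _get_current_dispatch_mode_stack()
    print([type(m).__name__ for m in stack])  # expected: ['_FlopCounterMode']
```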
```diff
@@ -290,7 +290,7 @@ def _offsets_to_lengths(offsets, max_len):
     """
     from torch._subclasses.fake_tensor import FakeTensor
     from torch._subclasses.functional_tensor import FunctionalTensor
 
-    if not isinstance(offsets, (FakeTensor, FunctionalTensor)):
+    if not isinstance(offsets, (FakeTensor, FunctionalTensor)) and offsets.device.type != "meta":
         return offsets.diff().tolist()
     return [max_len] * (offsets.size(0) - 1)
```
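To see why the extra meta check is needed: `offsets.diff().tolist()` reads tensor data, which a meta tensor does not have, so meta offsets must take the same `max_len` fallback as fake/functional tensors. A standalone restatement of the patched logic (the `FakeTensor`/`FunctionalTensor` branch is elided for brevity):

```python
# Standalone sketch of the patched fallback; only the meta path is shown.
import torch

def offsets_to_lengths_sketch(offsets, max_len):
    if offsets.device.type != "meta":
        # Real offsets: per-row lengths are consecutive differences.
        return offsets.diff().tolist()
    # Meta tensors carry no data, so pretend every row has max_len entries.
    return [max_len] * (offsets.size(0) - 1)

print(offsets_to_lengths_sketch(torch.tensor([0, 3, 7, 8]), max_len=16))
# [3, 4, 1]
print(offsets_to_lengths_sketch(torch.empty(4, device="meta"), max_len=16))
# [16, 16, 16]
```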