Updated flop counter to accept pytree inputs/outputs (#111990)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111990
Approved by: https://github.com/ezyang
Author:    chilli
Date:      2023-10-25 10:52:36 -07:00
Committer: PyTorch MergeBot
Parent:    d641450180
Commit:    74adb4cccc

2 changed files with 51 additions and 10 deletions
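In short: the counter's module hooks previously flattened module inputs/outputs through normalize_tuple, which mishandled modules whose forward takes or returns pytrees (dicts, nested tuples, and so on). Below is a minimal sketch of the usage this commit enables, mirroring the new test further down; the DictMM module is hypothetical, display=False merely suppresses the printed table, and the flop_counts key follows the module class name as in the test.

import torch
from torch.utils.flop_counter import FlopCounterMode

class DictMM(torch.nn.Module):
    # forward consumes and produces a dict -- a pytree, not a bare tensor
    def forward(self, x):
        return {"out": torch.mm(x["a"], x["a"])}

mod = DictMM()
mode = FlopCounterMode(mod, display=False)
with mode:
    mod({"a": torch.randn(10, 10)})

# a single 10x10 @ 10x10 matmul costs 2 * 10 * 10 * 10 = 2000 FLOPs
print(mode.flop_counts["DictMM"][torch.ops.aten.mm])  # -> 2000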

test/test_flop_counter.py

@@ -247,6 +247,40 @@ class TestFlopCounter(TestCase):
         self.assertEqual(len(model._forward_pre_hooks), 0)
         self.assertEqual(len(model._forward_hooks), 0)
 
+    def test_pytrees(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                x = x['a'].relu_()
+                return {'a': torch.mm(x, x)}
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = Foo()
+                self.b = Foo()
+
+            def forward(self, x):
+                return self.b(self.a(x))
+
+        mod = Mod()
+        mode = FlopCounterMode(mod)
+        with mode:
+            mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward()
+        self.assertExpectedInline((mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""")
+
+        class Mod2(torch.nn.Module):
+            def forward(self, x):
+                return (torch.mm(x, x),)
+
+        mod = Mod2()
+        mode = FlopCounterMode(mod)
+        with mode:
+            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
+        self.assertExpectedInline((mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""")
+
 if __name__ == '__main__':
     run_tests()
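Sanity check on the expected values (FLOP arithmetic only, not part of the diff): an aten.mm of two 10x10 matrices costs 2 * 10 * 10 * 10 = 2000 FLOPs, and its backward computes two more mm's of the same shape.

mm_fwd = 2 * 10 * 10 * 10              # one 10x10 @ 10x10 matmul
mm_bwd = 2 * mm_fwd                    # backward of mm = two mm's
assert 2 * (mm_fwd + mm_bwd) == 12000  # Mod: two mm's, forward + backward
assert 1 * (mm_fwd + mm_bwd) == 6000   # Mod2: one mm, forward + backward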

torch/utils/flop_counter.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from typing import List, Any, Dict, Optional, Union, NamedTuple
 from collections import defaultdict
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -248,13 +248,24 @@ def convert_to_percent_str(num, denom):
         return "0%"
     return f"{num / denom:.2%}"
 
+def _pytreeify_preserve_structure(f):
+    @wraps(f)
+    def nf(args):
+        flat_args, spec = tree_flatten(args)
+        out = f(*flat_args)
+        return tree_unflatten(out, spec)
+    return nf
+
 class FlopCounterMode(TorchDispatchMode):
     """
     ``FlopCounterMode`` is a context manager that counts the number of
     flops within its context. It does this using a ``TorchDispatchMode``.
 
-    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
+    It also supports hierarchical output by passing a module (or list of
+    modules) to FlopCounterMode on construction. If you do not need hierarchical
+    output, you do not need to use it with a module.
 
     Example usage
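For context, a small illustration of the pytree primitives the new helper builds on (tree_flatten and tree_unflatten are the real torch.utils._pytree functions imported above; the values are illustrative):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# tree_flatten reduces a nested structure to a flat list of leaves plus
# a spec recording the nesting; tree_unflatten inverts it.
flat, spec = tree_flatten(({"a": torch.ones(2)}, [torch.zeros(3)]))
# flat == [tensor([1., 1.]), tensor([0., 0., 0.])]
rebuilt = tree_unflatten([t.clone() for t in flat], spec)
# rebuilt == ({"a": tensor([1., 1.])}, [tensor([0., 0., 0.])])

_pytreeify_preserve_structure(f) is exactly this round trip: it flattens the hook's argument, calls f on the leaves, and reassembles f's output with the input's spec, so the structure the module handed in is the structure it gets back. That relies on f returning one leaf per leaf it received, which holds for the counter's pass-through autograd.Functions.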
@@ -298,6 +309,7 @@ class FlopCounterMode(TorchDispatchMode):
                 name = prefix
             else:
                 name = ".".join([prefix, name])
+
             forward_pre_hook_handle = module.register_forward_pre_hook(self._enter_module(name))
             forward_hook_handle = module.register_forward_hook(self._exit_module(name))
             self._module_to_forward_hook_handles[module] = _ForwardHookHandles(
@@ -312,16 +324,15 @@ class FlopCounterMode(TorchDispatchMode):
     def _enter_module(self, name):
         def f(module, inputs):
-            inputs = normalize_tuple(inputs)
-            out = self._create_pre_module(name)(*inputs)
+            out = _pytreeify_preserve_structure(self._create_pre_module(name))(inputs)
             return out
 
         return f
 
     def _exit_module(self, name):
         def f(module, inputs, outputs):
-            outputs = normalize_tuple(outputs)
-            return self._create_post_module(name)(*outputs)
+            outputs = _pytreeify_preserve_structure(self._create_post_module(name))(outputs)
+            return outputs
 
         return f
 
     def _create_post_module(self, name):
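The rewritten closures lean on standard nn.Module hook semantics: a forward pre-hook receives the tuple of positional args (whose elements may themselves be pytrees) and replaces it by returning a value, and a forward hook can likewise replace the output. A hedged sketch of that contract, reusing the hypothetical DictMM module from the first sketch and assuming all leaves are tensors:

from torch.utils._pytree import tree_flatten, tree_unflatten

def pre_hook(module, inputs):
    # `inputs` is the tuple of positional args; here it holds a dict,
    # which the old normalize_tuple-then-splat path could not handle cleanly.
    flat, spec = tree_flatten(inputs)
    return tree_unflatten([t.clone() for t in flat], spec)

def post_hook(module, inputs, output):
    # `output` is whatever forward returned -- a dict in this case.
    flat, spec = tree_flatten(output)
    return tree_unflatten([t.clone() for t in flat], spec)

mod = DictMM()
mod.register_forward_pre_hook(pre_hook)
mod.register_forward_hook(post_hook)
out = mod({"a": torch.randn(10, 10)})  # still a dict with key "out"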
@@ -331,8 +342,6 @@ class FlopCounterMode(TorchDispatchMode):
                 assert(self.parents[-1] == name)
                 self.parents.pop()
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
-                if len(args) == 1:
-                    return args[0]
                 return args
 
             @staticmethod
@@ -348,8 +357,6 @@ class FlopCounterMode(TorchDispatchMode):
             def forward(ctx, *args):
                 self.parents.append(name)
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
-                if len(args) == 1:
-                    return args[0]
                 return args
 
             @staticmethod
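The two deleted `if len(args) == 1` branches existed because outputs used to be re-splatted through normalize_tuple, so a lone tensor had to be unwrapped to match what the caller expected. Now that tree_unflatten restores the caller's structure, the autograd.Functions can always return the full tuple. A pared-down, hypothetical pass-through Function in that style (PassThrough is illustrative, not the counter's actual class):

import torch
from torch.utils._pytree import tree_map

class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        # Clone tensors so this node shows up in the autograd graph;
        # always return the full tuple -- no single-element unwrapping.
        return tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)

    @staticmethod
    def backward(ctx, *grad_outs):
        # Identity forward: gradients pass straight through, one per input
        # (assumes all inputs are tensors).
        return grad_outs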