diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index f8cf07d2fcfa..b17e539535d9 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -247,6 +247,40 @@ class TestFlopCounter(TestCase): self.assertEqual(len(model._forward_pre_hooks), 0) self.assertEqual(len(model._forward_hooks), 0) + def test_pytrees(self): + class Foo(torch.nn.Module): + def forward(self, x): + x = x['a'].relu_() + return {'a': torch.mm(x, x)} + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = Foo() + self.b = Foo() + + def forward(self, x): + return self.b(self.a(x)) + + mod = Mod() + mode = FlopCounterMode(mod) + with mode: + mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward() + self.assertExpectedInline((mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""") + + class Mod2(torch.nn.Module): + def forward(self, x): + return (torch.mm(x, x),) + + mod = Mod2() + mode = FlopCounterMode(mod) + with mode: + mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward() + self.assertExpectedInline((mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""") + + + + if __name__ == '__main__': run_tests() diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index c7dc1738ffe8..27319eb939f1 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten from typing import List, Any, Dict, Optional, Union, NamedTuple from collections import defaultdict from torch.utils._python_dispatch import TorchDispatchMode @@ -248,13 +248,24 @@ def convert_to_percent_str(num, denom): return "0%" return f"{num / denom:.2%}" +def _pytreeify_preserve_structure(f): + @wraps(f) + def nf(args): + flat_args, spec = tree_flatten(args) + out = f(*flat_args) + return tree_unflatten(out, spec) + + return nf + class FlopCounterMode(TorchDispatchMode): """ ``FlopCounterMode`` is a context manager that counts the number of flops within its context. It does this using a ``TorchDispatchMode``. - It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction. + It also supports hierarchical output by passing a module (or list of + modules) to FlopCounterMode on construction. If you do not need hierarchical + output, you do not need to use it with a module. Example usage @@ -298,6 +309,7 @@ class FlopCounterMode(TorchDispatchMode): name = prefix else: name = ".".join([prefix, name]) + forward_pre_hook_handle = module.register_forward_pre_hook(self._enter_module(name)) forward_hook_handle = module.register_forward_hook(self._exit_module(name)) self._module_to_forward_hook_handles[module] = _ForwardHookHandles( @@ -312,16 +324,15 @@ class FlopCounterMode(TorchDispatchMode): def _enter_module(self, name): def f(module, inputs): - inputs = normalize_tuple(inputs) - out = self._create_pre_module(name)(*inputs) + out = _pytreeify_preserve_structure(self._create_pre_module(name))(inputs) return out return f def _exit_module(self, name): def f(module, inputs, outputs): - outputs = normalize_tuple(outputs) - return self._create_post_module(name)(*outputs) + outputs = _pytreeify_preserve_structure(self._create_post_module(name))(outputs) + return outputs return f def _create_post_module(self, name): @@ -331,8 +342,6 @@ class FlopCounterMode(TorchDispatchMode): assert(self.parents[-1] == name) self.parents.pop() args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) - if len(args) == 1: - return args[0] return args @staticmethod @@ -348,8 +357,6 @@ class FlopCounterMode(TorchDispatchMode): def forward(ctx, *args): self.parents.append(name) args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args) - if len(args) == 1: - return args[0] return args @staticmethod