Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Updated flop counter to accept pytree inputs/outputs (#111990)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111990 Approved by: https://github.com/ezyang
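For context, a rough usage sketch of what this change enables: the per-module hooks previously normalized inputs/outputs to plain tuples, so a module that takes or returns a dict (or any other pytree) could not be tracked cleanly. The sketch below mirrors the new test added in this commit; the `DictMM` module name is made up for illustration, and the printed value assumes the counter's usual 2 * M * K * N accounting for `aten.mm`.

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Hypothetical module whose input and output are dicts (pytrees) rather than tensors/tuples.
class DictMM(torch.nn.Module):
    def forward(self, x):
        return {"out": torch.mm(x["a"], x["a"])}

mod = DictMM()
mode = FlopCounterMode(mod, display=False)
with mode:
    mod({"a": torch.randn(10, 10)})

# A 10x10 @ 10x10 matmul should be reported as 2 * 10 * 10 * 10 = 2000 flops
# under the module's entry (keyed by the module's class name, as in the test).
print(mode.flop_counts["DictMM"][torch.ops.aten.mm])
```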
@@ -247,6 +247,40 @@ class TestFlopCounter(TestCase):
         self.assertEqual(len(model._forward_pre_hooks), 0)
         self.assertEqual(len(model._forward_hooks), 0)
 
+    def test_pytrees(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                x = x['a'].relu_()
+                return {'a': torch.mm(x, x)}
+
+        class Mod(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = Foo()
+                self.b = Foo()
+
+            def forward(self, x):
+                return self.b(self.a(x))
+
+        mod = Mod()
+        mode = FlopCounterMode(mod)
+        with mode:
+            mod({'a': torch.randn(10, 10, requires_grad=True).clone()})['a'].sum().backward()
+        self.assertExpectedInline((mode.flop_counts['Mod'][torch.ops.aten.mm]), """12000""")
+
+        class Mod2(torch.nn.Module):
+            def forward(self, x):
+                return (torch.mm(x, x),)
+
+        mod = Mod2()
+        mode = FlopCounterMode(mod)
+        with mode:
+            mod(torch.randn(10, 10, requires_grad=True))[0].sum().backward()
+        self.assertExpectedInline((mode.flop_counts['Mod2'][torch.ops.aten.mm]), """6000""")
+
+
+
+
 
 if __name__ == '__main__':
     run_tests()
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from typing import List, Any, Dict, Optional, Union, NamedTuple
 from collections import defaultdict
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -248,13 +248,24 @@ def convert_to_percent_str(num, denom):
         return "0%"
     return f"{num / denom:.2%}"
 
+def _pytreeify_preserve_structure(f):
+    @wraps(f)
+    def nf(args):
+        flat_args, spec = tree_flatten(args)
+        out = f(*flat_args)
+        return tree_unflatten(out, spec)
+
+    return nf
+
 
 class FlopCounterMode(TorchDispatchMode):
     """
     ``FlopCounterMode`` is a context manager that counts the number of
     flops within its context. It does this using a ``TorchDispatchMode``.
 
-    It also supports hierarchical output by passing a module (or list of modules) to FlopCounterMode on construction.
+    It also supports hierarchical output by passing a module (or list of
+    modules) to FlopCounterMode on construction. If you do not need hierarchical
+    output, you do not need to use it with a module.
 
     Example usage
 
@@ -298,6 +309,7 @@ class FlopCounterMode(TorchDispatchMode):
                 name = prefix
             else:
                 name = ".".join([prefix, name])
+
             forward_pre_hook_handle = module.register_forward_pre_hook(self._enter_module(name))
             forward_hook_handle = module.register_forward_hook(self._exit_module(name))
             self._module_to_forward_hook_handles[module] = _ForwardHookHandles(
@@ -312,16 +324,15 @@ class FlopCounterMode(TorchDispatchMode):
 
     def _enter_module(self, name):
         def f(module, inputs):
-            inputs = normalize_tuple(inputs)
-            out = self._create_pre_module(name)(*inputs)
+            out = _pytreeify_preserve_structure(self._create_pre_module(name))(inputs)
            return out
 
         return f
 
     def _exit_module(self, name):
         def f(module, inputs, outputs):
-            outputs = normalize_tuple(outputs)
-            return self._create_post_module(name)(*outputs)
+            outputs = _pytreeify_preserve_structure(self._create_post_module(name))(outputs)
+            return outputs
         return f
 
     def _create_post_module(self, name):
@@ -331,8 +342,6 @@ class FlopCounterMode(TorchDispatchMode):
                 assert(self.parents[-1] == name)
                 self.parents.pop()
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
-                if len(args) == 1:
-                    return args[0]
                 return args
 
             @staticmethod
@@ -348,8 +357,6 @@ class FlopCounterMode(TorchDispatchMode):
             def forward(ctx, *args):
                 self.parents.append(name)
                 args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
-                if len(args) == 1:
-                    return args[0]
                 return args
 
             @staticmethod
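For readers skimming the diff, here is a rough standalone sketch, not the library code itself, of the structure-preserving pattern that the new `_pytreeify_preserve_structure` helper applies around the hooks: flatten the pytree into leaves, call a function that only understands positional leaf arguments, then rebuild the result using the input's spec. The `passthrough` function below is a made-up stand-in for the wrapped `autograd.Function.apply` call.

```python
from functools import wraps

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def pytreeify_preserve_structure(f):
    # Flatten the (possibly nested) args into leaves plus a spec, call f on
    # the leaves, then rebuild the output with the input's spec.
    @wraps(f)
    def nf(args):
        flat_args, spec = tree_flatten(args)
        out = f(*flat_args)
        return tree_unflatten(out, spec)
    return nf

# Made-up stand-in for the wrapped function: it receives the flat leaves and
# returns one output per input leaf, cloning tensors and passing others through.
def passthrough(*leaves):
    return [t.clone() if isinstance(t, torch.Tensor) else t for t in leaves]

wrapped = pytreeify_preserve_structure(passthrough)
inputs = {"a": torch.randn(2, 2), "b": (torch.randn(3), 1)}
outputs = wrapped(inputs)
# The dict/tuple nesting of `inputs` is preserved in `outputs`.
assert isinstance(outputs, dict) and isinstance(outputs["b"], tuple)
```

Note the assumption this pattern relies on: the wrapped function must return exactly one output per input leaf, since the output is rebuilt with the input's spec.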