autodiff changes to enable profiling

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25397

Differential Revision: D17565747

Pulled By: Krovatkin

fbshipit-source-id: b772437d9e02df99db6e662cb7d1227359959bed
This commit is contained in:
Nikolay Korovaiko
2019-09-25 10:10:10 -07:00
committed by Facebook Github Bot
parent 0cb10d7ebf
commit db5791d543
13 changed files with 183 additions and 79 deletions

View File

@ -22,7 +22,7 @@ class TestFuser(JitTestCase):
def assertAllFused(self, graph, except_for=()):
if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
graph = next(graph.nodes()).g('Subgraph')
allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::TupleConstruct'} | set(except_for)
self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
'got {}'.format(graph))
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
@ -878,8 +878,9 @@ class TestFuser(JitTestCase):
assert backward is None
backward = g
old_plans.add(str(backward))
self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs()
if o.node().kind() == "aten::_grad_sum_to_size"]), i)
self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs() if o.node().kind() == "prim::Param"]), 3 - i)
if __name__ == '__main__':