mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
autodiff changes to enable profiling
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/25397 Differential Revision: D17565747 Pulled By: Krovatkin fbshipit-source-id: b772437d9e02df99db6e662cb7d1227359959bed
This commit is contained in:
committed by
Facebook Github Bot
parent
0cb10d7ebf
commit
db5791d543
@ -22,7 +22,7 @@ class TestFuser(JitTestCase):
|
||||
def assertAllFused(self, graph, except_for=()):
|
||||
if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
|
||||
graph = next(graph.nodes()).g('Subgraph')
|
||||
allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
|
||||
allowed_nodes = {'prim::Constant', 'prim::FusionGroup', 'prim::TupleConstruct'} | set(except_for)
|
||||
self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
|
||||
'got {}'.format(graph))
|
||||
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
|
||||
@ -878,8 +878,9 @@ class TestFuser(JitTestCase):
|
||||
assert backward is None
|
||||
backward = g
|
||||
old_plans.add(str(backward))
|
||||
self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "aten::_grad_sum_to_size"]), i)
|
||||
self.assertEqual(len([1 for o in backward.outputs() if o.node().kind() == "prim::Param"]), 3 - i)
|
||||
self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs()
|
||||
if o.node().kind() == "aten::_grad_sum_to_size"]), i)
|
||||
self.assertEqual(len([1 for o in next(backward.outputs()).node().inputs() if o.node().kind() == "prim::Param"]), 3 - i)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user