mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	I feel it's easier to open a new PR rather than iterating on the previous PR (https://github.com/pytorch/pytorch/pull/105257) since this is more like a rewrite. In this PR, instead of changing GraphModule directly, which can easily cause BC issues, I create a LazyGraphModule class as Zachary & Jason suggested in comments on the previous PR. The difference between LazyGraphModule and GraphModule is mainly about how recompilation of the graph module happens. In GraphModule the recompilation happens 'eagerly': constructing a GraphModule triggers the recompilation. In LazyGraphModule, by contrast, we just mark the module as needing recompilation. The real recompilation only happens when absolutely required (e.g. calling the forward method, accessing the code property, etc.). In a lot of cases in torch.compile, the real recompilation is ultimately never triggered at all. This can save a few seconds of compilation time. By default, GraphModule rather than LazyGraphModule is used. The `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead. This has been applied to the torch.compile stack. Pull Request resolved: https://github.com/pytorch/pytorch/pull/117911 Approved by: https://github.com/jansel
		
			
				
	
	
		
			145 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			145 lines
		
	
	
		
			4.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: dynamo"]
 | |
| import operator
 | |
| 
 | |
| import torch
 | |
| 
 | |
| import torch._dynamo
 | |
| import torch._dynamo.config as config
 | |
| import torch._dynamo.test_case
 | |
| from torch._dynamo.testing import same
 | |
| from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module
 | |
| 
 | |
| 
 | |
class Seq(torch.nn.Module):
    """Small feed-forward stack: Linear(10, 10) -> ReLU -> Linear(10, 10) -> Sigmoid."""

    def __init__(self):
        super().__init__()
        # Stages are instantiated in pipeline order and handed to a single
        # Sequential container stored on ``self.layers``.
        stages = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        out = self.layers(x)
        return out
 | |
| 
 | |
| 
 | |
class Conv_Bn_Relu(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d (eps=0.001) and ReLU.

    Extra keyword arguments are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        norm = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.conv = conv
        self.bn = norm
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply convolution, then batch norm, then ReLU to ``x``."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.relu(normalized)
 | |
| 
 | |
| 
 | |
def toy_example(a, b):
    """Return ``a / (abs(a) + 1)`` multiplied by ``b``.

    ``b`` is first multiplied by -1 when its sum is negative, so the
    function branches on a value computed from its input.
    """
    scaled = a / (torch.abs(a) + 1)
    sum_is_negative = b.sum() < 0
    if sum_is_negative:
        b = b * -1
    return scaled * b
 | |
| 
 | |
| 
 | |
def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Retarget every ``operator.mul`` call in ``gm``'s graph to ``operator.add``.

    The graph is modified in place, linted, and the module is recompiled;
    the same ``gm`` object is returned.
    """
    for node in gm.graph.nodes:
        # Only "call_function" nodes carry a callable in ``target``;
        # of those, only multiplications are swapped.
        is_mul_call = node.op == "call_function" and node.target == operator.mul
        if is_mul_call:
            node.target = operator.add

    # Check that the graph is still well-formed before regenerating code.
    gm.graph.lint()
    gm.recompile()
    return gm
 | |
| 
 | |
| 
 | |
@config.patch("verify_correctness", True)
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
    """Tests for dynamo's ``verify_correctness`` config option.

    The class is decorated with ``config.patch("verify_correctness", True)``;
    ``test_incorrect_verify_false`` overrides it with ``False``.
    """

    def test_example_inputs(self):
        """The example inputs given to the backend yield a result matching eager in shape and device."""

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            # Record what the captured graph produces on the example inputs
            # so it can be compared with the eager and optimized results.
            nonlocal r1
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        # The backend must have been invoked at least once.
        self.assertIsNotNone(r1)

        self.assertEqual(r1.shape, r2.shape)
        self.assertEqual(r1.shape, r3.shape)
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)

    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        """The "ts" backend produces the same output as eager for ``Seq``."""
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch._dynamo.optimize("ts")(s)
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        If a bad optimization returns a graph that is not
        functionally equal to the original graph, then with
        config.verify_correctness=True the outputs are checked
        for correctness and an error is raised.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        toy_example(i1, i2)
        # Both the wrapping and the call stay inside the context manager so a
        # RuntimeError from either step is accepted, as in the earlier
        # try/except form of this test.
        with self.assertRaises(RuntimeError):
            opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
            opt_toy_example(i1, i2)

    @config.patch("verify_correctness", False)
    def test_incorrect_verify_false(self):
        """
        The bad optimization returns a graph that is not
        functionally equal to the original graph; with
        config.verify_correctness=False the wrong outputs
        are returned without an error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        r1 = toy_example(i1, i2)
        opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
        r2 = opt_toy_example(i1, i2)
        self.assertFalse(same(r1, r2))
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Run this file's tests through the dynamo test-case runner when the
    # module is executed as a script.
    torch._dynamo.test_case.run_tests()
 |