pytorch/test/dynamo/test_verify_correctness.py
Shunting Zhang fe10b1800f LazyGraphModule (#117911)
I feel it's easier to open a new PR than to iterate on the previous PR (https://github.com/pytorch/pytorch/pull/105257), since this is more like a rewrite.

In this PR, instead of changing GraphModule directly, which can easily cause BC issues, I create a LazyGraphModule class, as Zachary & Jason suggested in comments on the previous PR.

The difference between LazyGraphModule and GraphModule is mainly in how recompilation of the graph module happens. In GraphModule, recompilation happens eagerly: constructing a GraphModule triggers it. In LazyGraphModule, we just mark the module as needing recompilation, and the real recompilation only happens when it is absolutely required (e.g. when the forward method is called or the code property is accessed). In many cases in torch.compile, the real recompilation is never triggered at all, which can save a few seconds of compilation time.
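A minimal sketch of the contrast, using the public torch.fx tracing API (torch.fx.symbolic_trace); the LazyGraphModule side is paraphrased from the description above rather than demonstrated end to end:

```python
import torch
import torch.fx


def f(x):
    return x.relu() + 1


# GraphModule recompiles eagerly: construction already generates the Python
# source for forward(), so gm.code is available immediately.
gm = torch.fx.symbolic_trace(f)
print(gm.code)

# A LazyGraphModule built from the same graph would instead only mark itself
# as needing recompilation; the real recompile is deferred until e.g. the
# first forward() call or an access to .code.
```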

By default, GraphModule rather than LazyGraphModule is used. The `use_lazy_graph_module(True)` context manager can be used to pick LazyGraphModule instead; this has been applied to the torch.compile stack.
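A hedged usage sketch: the context manager is assumed to live in the same private module the test below imports from (torch.fx._lazy_graph_module), and the underscore-prefixed name `_use_lazy_graph_module` is my assumption; the PR text only calls it `use_lazy_graph_module(True)`.

```python
import torch
from torch.fx._lazy_graph_module import _use_lazy_graph_module  # name assumed


def f(x):
    return x.sin() + x.cos()


# Inside this context, graph modules created by the torch.compile stack are
# lazy: their recompile() is deferred until the code/forward is actually needed.
with _use_lazy_graph_module(True):
    compiled = torch.compile(f)
    compiled(torch.randn(4))
```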

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117911
Approved by: https://github.com/jansel
2024-01-27 04:10:18 +00:00


# Owner(s): ["module: dynamo"]
import operator
import torch
import torch._dynamo
import torch._dynamo.config as config
import torch._dynamo.test_case
from torch._dynamo.testing import same
from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module


class Seq(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


class Conv_Bn_Relu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b


def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # Checks if we're calling a function (i.e. operator.add)
        if node.op == "call_function":
            # The target attribute is the function that call_function calls.
            if node.target == operator.mul:
                node.target = operator.add

    gm.graph.lint()  # Does some checks to make sure the Graph is well-formed.
    gm.recompile()
    return gm


@config.patch("verify_correctness", True)
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
    def test_example_inputs(self):
        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            nonlocal r1
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)
        self.assertIsNotNone(r1)
        self.assertEqual(r1.shape, r2.shape)
        self.assertEqual(r1.shape, r3.shape)
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)

    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch._dynamo.optimize("ts")(s)
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        If a bad optimization returns a graph that is not functionally
        equivalent to the original graph and config.verify_correctness=True,
        Dynamo checks the correctness of the outputs and raises an error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        toy_example(i1, i2)
        try:
            opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
            opt_toy_example(i1, i2)
        except RuntimeError:
            pass
        else:
            self.fail("expected failure")

    @config.patch("verify_correctness", False)
    def test_incorrect_verify_false(self):
        """
        The bad optimization returns a graph that is not functionally
        equivalent to the original graph; with config.verify_correctness=False,
        the wrong outputs are returned without any error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        r1 = toy_example(i1, i2)
        opt_toy_example = torch._dynamo.optimize(incorrect_compile_fn)(toy_example)
        r2 = opt_toy_example(i1, i2)
        self.assertTrue(not same(r1, r2))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()