pytorch/test/dynamo/test_verify_correctness.py

# Owner(s): ["module: dynamo"]
import operator
import torch
import torch._dynamo
import torch._dynamo.config as config
import torch._dynamo.test_case
from torch._dynamo.testing import same
from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module


class Seq(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        return self.layers(x)


class Conv_Bn_Relu(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b
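
# (Added note) The branch on `b.sum()` is data-dependent, so Dynamo cannot
# capture this function in a single graph; it graph-breaks at the `if` and
# compiles the code before and after it separately.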


def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        # Check whether this node calls a function (e.g. operator.add).
        if node.op == "call_function":
            # The target attribute is the function that call_function calls.
            if node.target == operator.mul:
                node.target = operator.add
    gm.graph.lint()  # Check that the graph is still well-formed.
    gm.recompile()
    return gm
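

# A minimal sketch (not part of the original file) of what `transform` does:
# tracing a function that multiplies its inputs yields a call_function node
# targeting operator.mul, which `transform` rewrites to operator.add.
# `_demo_transform` and `demo_fn` are hypothetical names for illustration.
def _demo_transform():
    def demo_fn(x, y):
        return x * y  # traced as call_function(operator.mul)

    gm = torch.fx.symbolic_trace(demo_fn)
    gm = transform(gm)
    # After the rewrite, gm(2, 3) returns 5 rather than 6.
    return gm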
@config.patch("verify_correctness", True)
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
def test_example_inputs(self):
def fn(a, bc, d):
b, c = bc
return a / d - b / c
def compiler_fn(graph, example_inputs):
nonlocal r1
r1 = graph(*example_inputs)[0]
return graph.forward
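
        # (Added note) A Dynamo backend receives the traced GraphModule and
        # the real example inputs; this one records the graph's output on
        # those inputs before returning the unmodified forward.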
        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertEqual(r1.shape, r2.shape)
        self.assertEqual(r1.shape, r3.shape)
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)
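
    # (Added note) "ts" selects the TorchScript backend; the decorator below
    # appears to force Dynamo to hand the backend a real GraphModule rather
    # than the lazy wrapper, so it can be scripted directly.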
    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch.compile(s, backend="ts")
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        If a bad optimization returns a graph that is not functionally
        equivalent to the original graph, and config.verify_correctness=True,
        Dynamo checks the correctness of the outputs and raises an error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        toy_example(i1, i2)
        try:
            opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
            opt_toy_example(i1, i2)
        except RuntimeError:
            pass
        else:
            self.fail("expected failure")
@config.patch("verify_correctness", False)
def test_incorrect_verify_false(self):
"""
The bad optimization return a graph that
is not functionally equal to the original graph;
When config.verify_correctness=False, wrong outputs
will return
"""
i1 = torch.randn(10)
i2 = torch.randn(10)
def incorrect_compile_fn(gm, example_inputs):
return transform(gm).forward
r1 = toy_example(i1, i2)
opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
r2 = opt_toy_example(i1, i2)
self.assertTrue(not same(r1, r2))
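

# A minimal sketch (not in the original file) of a backend that passes the
# verify_correctness check, for contrast with incorrect_compile_fn above: it
# returns the traced graph's forward unchanged, so compiled and eager outputs
# always match. `_identity_backend` is a hypothetical name.
def _identity_backend(gm, example_inputs):
    return gm.forward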


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()