mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
related commits: - #139706 - #140238 - #140247 - #140253 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238 Approved by: https://github.com/soulitzer
144 lines
4.1 KiB
Python
144 lines
4.1 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import operator
|
|
|
|
import torch
|
|
import torch._dynamo
|
|
import torch._dynamo.config as config
|
|
import torch._dynamo.test_case
|
|
from torch._dynamo.testing import same
|
|
from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module
|
|
|
|
|
|
class Seq(torch.nn.Module):
    """Linear(10, 10) -> ReLU -> Linear(10, 10) -> Sigmoid, wrapped in ``Sequential``."""

    def __init__(self) -> None:
        super().__init__()
        # Built in the same order as they run, so parameter initialisation
        # order (and therefore RNG consumption) matches a literal listing.
        stages = (
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        )
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Feed ``x`` through every stage of ``self.layers`` and return the output."""
        out = self.layers(x)
        return out
|
|
|
|
|
|
class Conv_Bn_Relu(torch.nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (eps=0.001) and ``ReLU``.

    Any extra keyword arguments (e.g. ``kernel_size``, ``stride``, ``padding``)
    are forwarded unchanged to ``torch.nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The conv has no bias of its own; the batch-norm that follows
        # applies its own affine shift.
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply conv, then batch-norm, then ReLU to ``x``."""
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)
|
|
|
|
|
|
def toy_example(a, b):
    """Return ``a / (|a| + 1)`` multiplied elementwise by ``b``.

    If ``b`` sums to a negative value it is first replaced by ``b * -1``
    (a new tensor; the caller's ``b`` is not modified).

    Both products are written with ``*`` on purpose: ``transform`` in this
    file rewrites ``operator.mul`` call nodes, and the tests presumably rely
    on these multiplications being traced as such nodes.
    """
    denom = torch.abs(a) + 1
    squashed = a / denom
    if b.sum() < 0:
        b = b * -1
    return squashed * b
|
|
|
|
|
|
def transform(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Deliberately break ``gm`` by turning every multiplication into an addition.

    Each ``call_function`` node whose target is ``operator.mul`` is retargeted,
    in place, to ``operator.add``. The graph is then linted and the module
    recompiled so that ``gm.forward`` reflects the edited graph.

    Returns the same (mutated) ``gm`` object.
    """
    graph = gm.graph
    for node in graph.nodes:
        # Only function-call nodes carry a plain callable (e.g. operator.add)
        # as their target; skip placeholders, outputs, method/module calls.
        if node.op != "call_function":
            continue
        if node.target == operator.mul:
            node.target = operator.add

    # Sanity-check that the edited graph is still well-formed.
    graph.lint()
    # Regenerate the module's Python code from the modified graph.
    gm.recompile()
    return gm
|
|
|
|
|
|
@config.patch("verify_correctness", True)
class TestVerifyCorrectness(torch._dynamo.test_case.TestCase):
    """Tests for dynamo's ``verify_correctness`` config flag.

    The class-level ``config.patch`` enables ``verify_correctness`` for these
    tests; an individual test may override it with its own ``config.patch``.
    """

    def test_example_inputs(self):
        """The example inputs handed to the backend match the real call.

        ``compiler_fn`` runs the captured graph on the example inputs it
        receives. The first output must agree in shape and device with both
        the eager result and the compiled result.
        """

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            nonlocal r1
            # The graph's output is indexed, so take its first element.
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        # compiler_fn must actually have been invoked.
        self.assertIsNotNone(r1)
        self.assertEqual(r1.shape, r2.shape)
        self.assertEqual(r1.shape, r3.shape)
        self.assertEqual(r1.device, r2.device)
        self.assertEqual(r1.device, r3.device)

    # NOTE(review): presumably forces eager GraphModule recompilation so the
    # TorchScript backend sees fully materialized code — confirm.
    @_force_skip_lazy_graph_module()
    def test_torchscript(self):
        """Compiling ``Seq`` with the ``"ts"`` backend matches eager output."""
        s = Seq()
        i = torch.randn(10)
        r1 = s(i)
        opt_s = torch.compile(s, backend="ts")
        r2 = opt_s(i)
        self.assertTrue(same(r1, r2))

    def test_incorrect_verify_true(self):
        """
        If a bad optimization returns a graph that
        is not functionally equal to the original graph,
        and config.verify_correctness=True, the outputs are
        checked for correctness and an error is raised.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
        # Only the compiled call is expected to raise; keeping compile() setup
        # outside the context means an unrelated setup error is not mistaken
        # for the verification failure.
        with self.assertRaises(RuntimeError):
            opt_toy_example(i1, i2)

    @config.patch("verify_correctness", False)
    def test_incorrect_verify_false(self):
        """
        The bad optimization returns a graph that
        is not functionally equal to the original graph.
        When config.verify_correctness=False, the wrong
        outputs are returned without an error.
        """
        i1 = torch.randn(10)
        i2 = torch.randn(10)

        def incorrect_compile_fn(gm, example_inputs):
            return transform(gm).forward

        r1 = toy_example(i1, i2)
        opt_toy_example = torch.compile(toy_example, backend=incorrect_compile_fn)
        r2 = opt_toy_example(i1, i2)
        self.assertFalse(same(r1, r2))
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly with the dynamo test runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
|