# Source: pytorch/test/dynamo/test_export.py (1430 lines, 46 KiB, Python)

# Owner(s): ["module: dynamo"]
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import make_fx
class ExportTests(torch._dynamo.test_case.TestCase):
# TODO(voz): Refactor to a shared test function.
# The tests in this file are a little redundant:
# they all take a func, run it with eager, then export it, then compare
# the two results.
def test_export(self):
    """Export a zero-argument function built entirely from constant tensors.

    The exported graph is invoked with no inputs, and its nested-list output
    must match running ``func`` under Dynamo's eager backend.
    """

    def pre_attention_state_ops(_input, _mems, state):
        # Only ``state`` is read; the first two arguments are accepted but unused.
        lc_key, lc_val = state[0], state[1]
        outer = []
        for _ in range(4):
            inner = []
            for _ in range(3):
                inner.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            outer.append(inner)
        return outer

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        i = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(i, mems, state)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)()
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module()))
def test_export_mismatched_out(self):
    """Export a function whose output mixes a list and a tuple of tensors."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def make_input():
        # A fresh tensor per call, one each for eager, export and replay.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(make_input())
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, make_input())
    actual = graph_module(make_input())
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_bypass(self):
    """Export a function that reads only one element of a list argument.

    The exported graph takes the flattened list and must reproduce the eager
    result.
    """
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return first * second

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_list_unpack(self):
    """Export a function returning list elements interleaved with a product."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return x[0], first * second, x[1], x[2]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_mismatched_out_2(self):
    """Same list/tuple mixed-output scenario as ``test_export_mismatched_out``."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def make_input():
        # A fresh tensor per call, one each for eager, export and replay.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(make_input())
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, make_input())
    actual = graph_module(make_input())
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_with_list(self):
    """Export a function that returns a computed tensor plus its list input."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return first * second, x

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_with_complex_reorder(self):
    """Export a function whose outputs reorder inputs and mix in products."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]

    def func(x):
        first = x[0]
        second = x[1]
        third = x[2]
        return third, first, second, first * second, first * third

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_dupes(self):
    """Export a function that returns the same computed tensor twice."""
    example = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(example)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, example)
    leaves = pytree.tree_flatten(example)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_2(self):
    """Same duplicated-output scenario as ``test_dupes``."""
    example = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(example)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, example)
    leaves = pytree.tree_flatten(example)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass(self):
    """Duplicated output plus an input passed straight through to the output."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4])]

    def func(x, z):
        y = x + 1
        return y, y, z

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args)
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass_with_non_tensor_arg(self):
    """Dupes and pass-through with an extra Python int argument."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        y = x + k
        return y, y, z

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args)
    # The int is a pytree leaf too, so it is forwarded to the graph.
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
    """Like the non-reordered variant, but the pass-through output comes first."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        y = x + k
        return z, y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args)
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_dupes_and_bypass_with_non_tensor_output(self):
    """A Python scalar (``.item()``) output alongside tensor outputs.

    ``capture_scalar_outputs`` is patched on so Dynamo keeps ``.item()`` in
    the graph.
    """
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        y = x + k
        return y[0].item(), y, z

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args)
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_zeroes_in_and_out_different_shape_on_test(self):
    """Export with all-zero inputs, then replay the graph on random inputs."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out(self):
    """Scalar-only output traced on zeros and replayed on random inputs."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return a[0].item() + b[0].item() + c[0].item()

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out_permute(self):
    """Scalar output whose terms read inputs out of order, one of them twice."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return b[0].item() + c[0].item() + a[0].item() + a[0].item()

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
    """Permuted scalar output sandwiched between two pass-throughs of ``a``."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_func_return(self):
    """Export a function that returns the result of an inner closure call."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        x = a + b + c

        def func2(y):
            return x * y

        return func2(x)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dict_return(self):
    """Export a function whose output is a dict of tensors."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        x = a + b + c
        return {"a": x}

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_export_with_aten_graph(self):
    """Aten-graph variant of ``test_export``: constant-only zero-arg function."""

    def pre_attention_state_ops(_input, _mems, state):
        # Only ``state`` is read; the first two arguments are accepted but unused.
        lc_key, lc_val = state[0], state[1]
        outer = []
        for _ in range(4):
            inner = []
            for _ in range(3):
                inner.append(
                    lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                )
            outer.append(inner)
        return outer

    def func():
        mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
        state = [
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
        ]
        i = torch.tensor(
            [
                [0.0313, -0.1487, -0.3846, -0.5321],
                [-1.7073, 1.3331, -0.0890, -1.4935],
                [-0.8314, -0.1862, -0.5935, 1.5232],
            ]
        )
        return pre_attention_state_ops(i, mems, state)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)()
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, aten_graph=True)
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module()))
def test_export_mismatched_out_with_aten_graph(self):
    """Aten-graph export of a list/tuple mixed-output function."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def make_input():
        # A fresh tensor per call, one each for eager, export and replay.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(make_input())
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, make_input(), aten_graph=True)
    actual = graph_module(make_input())
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_bypass_with_aten_graph(self):
    """Aten-graph export of a function reading one element of a list input."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return first * second

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors, aten_graph=True)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_list_unpack_with_aten_graph(self):
    """Aten-graph export returning list elements interleaved with a product."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return x[0], first * second, x[1], x[2]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors, aten_graph=True)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_mismatched_out_2_with_aten_graph(self):
    """Same scenario as ``test_export_mismatched_out_with_aten_graph``."""

    def func(x):
        y = x + 1
        return ([x, x], (y, y))

    def make_input():
        # A fresh tensor per call, one each for eager, export and replay.
        return torch.tensor([[[1.3737, 0.1]]])

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(make_input())
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, make_input(), aten_graph=True)
    actual = graph_module(make_input())
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_with_list_with_aten_graph(self):
    """Aten-graph export returning a computed tensor plus its list input."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]

    def func(x):
        first = x[2]
        second = x[2]
        return first * second, x

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors, aten_graph=True)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_export_graph_with_complex_reorder_with_aten_graph(self):
    """Aten-graph export whose outputs reorder inputs and mix in products."""
    tensors = [
        torch.tensor([0.1, 0.1]),
        torch.tensor([0.2, 0.2]),
        torch.tensor([0.3, 0.3]),
        torch.tensor([0.4, 0.4]),
    ]

    def func(x):
        first = x[0]
        second = x[1]
        third = x[2]
        return third, first, second, first * second, first * third

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(tensors)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, tensors, aten_graph=True)
    leaves = pytree.tree_flatten(tensors)[0]
    actual = graph_module(*leaves)
    self.assertTrue(torch._dynamo.utils.same(expected, actual))
def test_dupes_with_aten_graph(self):
    """Aten-graph export of a function returning the same tensor twice."""
    example = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(example)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, example, aten_graph=True)
    leaves = pytree.tree_flatten(example)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_2_with_aten_graph(self):
    """Same duplicated-output scenario as ``test_dupes_with_aten_graph``."""
    example = torch.tensor([0.1, 0.1])

    def func(x):
        y = x + 1
        return y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(example)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, example, aten_graph=True)
    leaves = pytree.tree_flatten(example)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass_with_aten_graph(self):
    """Aten-graph export with a duplicated output and a pass-through input."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.4, 0.4])]

    def func(x, z):
        y = x + 1
        return y, y, z

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
    """Aten-graph export: dupes and pass-through with an extra int argument."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        y = x + k
        return y, y, z

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    # The int is a pytree leaf too, so it is forwarded to the graph.
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
    """Aten-graph export with the pass-through output placed first."""
    args = [torch.tensor([0.1, 0.1]), torch.tensor([0.1, 0.1]), 4]

    def func(x, z, k):
        y = x + k
        return z, y, y

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *args, aten_graph=True)
    leaves = pytree.tree_flatten(args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
    """Aten-graph export with a scalar ``.item()`` output, dupes and a bypass.

    The export must request ``aten_graph=True``. Without it, this test runs
    exactly the same check as
    ``test_dupes_and_bypass_with_non_tensor_output`` and never exercises the
    aten path its name promises.
    """
    inp = torch.tensor([0.1, 0.1])
    inp2 = torch.tensor([0.1, 0.1])
    inp3 = 4
    inps = [inp, inp2, inp3]

    def func(x, z, k):
        y = x + k
        return y[0].item(), y, z

    opt_func = torch._dynamo.optimize("eager", nopython=True)(func)
    real_result = opt_func(*inps)

    torch._dynamo.reset()

    # aten_graph=True is what distinguishes this test from its non-aten sibling.
    exported = torch._dynamo.export(func, *inps, aten_graph=True)
    out_graph = exported[0]
    flat_input, _ = pytree.tree_flatten(inps)
    dynamo_result = out_graph(*flat_input)
    self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
    """Aten-graph export traced on zeros and replayed on random inputs."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        return [[a], [b, c], [a + b], [[c + c]]]

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args, aten_graph=True)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_func_return_with_aten_graph(self):
    """Aten-graph export of a function returning an inner closure's result."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        x = a + b + c

        def func2(y):
            return x * y

        return func2(x)

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args, aten_graph=True)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_dict_return_with_aten_graph(self):
    """Aten-graph export of a function returning a dict of tensors."""
    trace_args = [torch.zeros(10) for _ in range(3)]
    run_args = [torch.randn(10) for _ in range(3)]

    def func(a, b, c):
        x = a + b + c
        return {"a": x}

    expected = torch._dynamo.optimize("eager", nopython=True)(func)(*run_args)
    torch._dynamo.reset()

    graph_module, _ = torch._dynamo.export(func, *trace_args, aten_graph=True)
    leaves = pytree.tree_flatten(run_args)[0]
    self.assertTrue(torch._dynamo.utils.same(expected, graph_module(*leaves)))
def test_export_with_stack_trace(self):
    """Nodes of an exported graph must carry a ``stack_trace``.

    The check runs twice. For the Dynamo-level graph, every node except
    placeholders and the output is checked. For the aten-level graph
    (``aten_graph=True``), every ``call_function`` node is checked.
    """
    inp = torch.tensor([0.1, 0.1])
    linear = torch.nn.Linear(2, 2)

    def func(x):
        x = x + 1
        y = x.t()
        y = y.relu()
        y = linear(y)
        return y

    exported = torch._dynamo.export(func, inp, aten_graph=False)
    out_graph = exported[0]
    for node in out_graph.graph.nodes:
        if node.op not in {"placeholder", "output"}:
            # assertIsNotNone with the node named in msg pinpoints the
            # offending node; assertTrue(... is not None) only reports
            # "False is not true".
            self.assertIsNotNone(
                node.stack_trace,
                msg=f"{node.op} node {node.name!r} has no stack_trace",
            )

    torch._dynamo.reset()

    exported = torch._dynamo.export(func, inp, aten_graph=True)
    out_graph = exported[0]
    for node in out_graph.graph.nodes:
        if node.op == "call_function":
            self.assertIsNotNone(
                node.stack_trace,
                msg=f"{node.op} node {node.name!r} has no stack_trace",
            )
def test_export_compare_optimize_with_make_fx(self):
    """The aten graph from export should line up with make_fx on Dynamo's graph.

    ``compiler`` acts as a custom Dynamo backend. It lowers the captured FX
    graph with ``make_fx``, then compares it node by node with the graph from
    ``export(aten_graph=True)``. The comparison covers op kind, plus target,
    arity and per-argument type for ``call_function`` nodes. Finally, the
    backend's output must equal the exported graph's output.
    """
    example = torch.tensor([0.1, 0.1])
    linear = torch.nn.Linear(2, 2)

    def func(x):
        x = x + 1
        y = x.t()
        y = y.relu()
        y = linear(y)
        return y

    exported_gm, _ = torch._dynamo.export(func, example, aten_graph=True)
    export_result = exported_gm(example)

    torch._dynamo.reset()

    def compiler(gm, sample_inputs):
        aten_gm = make_fx(gm)(*sample_inputs)
        traced_nodes = list(aten_gm.graph.nodes)
        exported_nodes = list(exported_gm.graph.nodes)
        self.assertEqual(len(traced_nodes), len(exported_nodes))
        for traced, exported in zip(traced_nodes, exported_nodes):
            self.assertEqual(traced.op, exported.op)
            if traced.op != "call_function":
                continue
            self.assertEqual(traced.target, exported.target)
            self.assertEqual(len(traced.args), len(exported.args))
            for lhs, rhs in zip(traced.args, exported.args):
                self.assertEqual(type(lhs), type(rhs))
        return aten_gm.forward

    optimized = torch._dynamo.optimize(compiler, nopython=True)(func)
    self.assertTrue(torch._dynamo.utils.same(optimized(example), export_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_method_on_module(self):
    """An ``assume_constant_result`` method on the module is exported.

    The graph is exported from one input, then replayed on two different
    inputs. Each replay must equal the eager result computed on a separate
    module instance.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x)
            return y

    eager_module = MyModule()
    real_result = eager_module(torch.tensor([[1.0, 0], [0, 0]]))

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(exported_module, torch.tensor([[0.0, 0], [0, 0]]))

    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_method_on_module_invoke_twice(self):
    """A constant-result method called twice in ``forward`` is exported.

    Replays on two new inputs must equal the eager result.
    """

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = self.helper_fn(x) + self.helper_fn(x)
            return y

    eager_module = MyModule()
    real_result = eager_module(torch.tensor([[1.0, 0], [0, 0]]))

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(exported_module, torch.tensor([[0.0, 0], [0, 0]]))

    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_free_function(self):
    """``forward`` sums a constant-result free function and a same-named method.

    Replays on two new inputs must equal the eager result.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return torch.nonzero(x)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + self.helper_fn(x)
            return y

    eager_module = MyModule()
    real_result = eager_module(torch.tensor([[1.0, 0], [0, 0]]))

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(exported_module, torch.tensor([[0.0, 0], [0, 0]]))

    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_free_function_and_class_method(self):
    """``forward`` calls only a constant-result free function.

    Replays on two new inputs must equal the eager result.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x)
            return y

    eager_module = MyModule()
    real_result = eager_module(torch.tensor([[1.0, 0], [0, 0]]))

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(exported_module, torch.tensor([[0.0, 0], [0, 0]]))

    probes = (
        torch.tensor([[1.0, 0.0], [0, 0]]),
        torch.tensor([[1, 0], [0.25, 0.25]]),
    )
    for probe in probes:
        self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_free_function_and_class_method_multiarg(self):
    """Two-argument ``forward`` sums a constant-result free function on each input.

    Replays on two new input pairs must equal the eager result.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(4, 2))
            self.linear = torch.nn.Linear(2, 2)

        def forward(self, x, z):
            y = torch.sin(x)
            x = self.linear(x)
            y = helper_fn(x) + helper_fn(z)
            return y

    eager_module = MyModule()
    real_result = eager_module(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(
        exported_module,
        torch.tensor([[0.0, 0], [0, 0]]),
        torch.tensor([[1.0, 0], [0, 0]]),
    )

    probe_pairs = (
        (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])),
        (torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])),
    )
    for x_probe, z_probe in probe_pairs:
        self.assertTrue(
            torch._dynamo.utils.same(graph(x_probe, z_probe), real_result)
        )
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
    """Two-argument constant-result export with different inputs at each stage.

    The eager run, the export and both replays each use different tensors.
    The replays must still equal the eager result.
    """

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return torch.nonzero(x)

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, z):
            y = helper_fn(x) + helper_fn(z)
            return y

    eager_module = MyModule()
    real_result = eager_module(
        torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
    )

    exported_module = MyModule()
    graph, _ = torch._dynamo.export(
        exported_module,
        torch.tensor([[0.0, 0], [0, 0]]),
        torch.tensor([[0.0, 0], [0.5, 0]]),
    )

    probe_pairs = (
        (torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])),
        (
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
        ),
    )
    for x_probe, z_probe in probe_pairs:
        self.assertTrue(
            torch._dynamo.utils.same(graph(x_probe, z_probe), real_result)
        )
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_tuple_nonzero(self):
    """A constant-result method returning a tuple of tensors is iterated in ``forward``."""

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return (torch.nonzero(x), torch.nonzero(x))

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    real_result = module(torch.tensor([1.0, 1.0]))
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))

    # The exported graph reproduces the export-time helper result, so even an
    # input of a different shape yields the eager output.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_list_nonzero(self):
    """A constant-result method returning a list of tensors is iterated in ``forward``."""

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return [torch.nonzero(x), torch.nonzero(x)]

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    real_result = module(torch.tensor([1.0, 1.0]))
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))

    # The exported graph reproduces the export-time helper result, so even an
    # input of a different shape yields the eager output.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_list_nonzero_free_function(self):
    """Like the list-nonzero test, but the constant-result helper is a free function."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return [torch.nonzero(x), torch.nonzero(x)]

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            elements = helper_fn(x)
            all_y = []
            for element in elements:
                for item in element:
                    all_y.append(y * item)
            return all_y

    module = MyModule()
    real_result = module(torch.tensor([1.0, 1.0]))
    graph, _ = torch._dynamo.export(module, torch.tensor([1.0, 1.0]))

    # The exported graph reproduces the export-time helper result, so even an
    # input of a different shape yields the eager output.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_dict_values(self):
    """A constant-result method returning a dict is indexed by key in ``forward``."""

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return {"x": x, "x^2": x * x}

        def forward(self, x):
            y = torch.tensor([0.5])
            elements = self.helper_fn(x)
            y = y * elements["x"]
            y = y * elements["x^2"]
            return y

    module = MyModule()
    real_result = module(torch.tensor([2.0, 2.0]))
    graph, _ = torch._dynamo.export(module, torch.tensor([2.0, 2.0]))

    # The exported graph reproduces the export-time helper result, so even an
    # input of a different shape yields the eager output.
    probe = torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]])
    self.assertTrue(torch._dynamo.utils.same(graph(probe), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_none_control_flow(self):
    """A constant-result method that returned ``None`` at export time.

    The ``is None`` branch is therefore baked into the exported graph, even
    for inputs that would take the other branch in eager mode.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([-1]))
    # Negative input: helper_fn returns None, so forward returns y unchanged.
    self.assertEqual(real_result, torch.tensor([0.5]))

    graph, _ = torch._dynamo.export(module, torch.tensor([-1]))
    # Positive input, but the export-time None is baked in, so y is still returned.
    self.assertTrue(torch._dynamo.utils.same(graph(torch.tensor([2])), real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_not_none_control_flow(self):
    """A constant-result method that returned a tensor at export time.

    The ``y * x`` branch and the export-time ``x`` are therefore baked into
    the graph, even for inputs that would return ``None`` in eager mode.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([2]))
    # Positive input: helper_fn returns x, so forward returns 0.5 * 2.
    self.assertEqual(real_result, torch.tensor([1.0]))

    graph, _ = torch._dynamo.export(module, torch.tensor([2]))
    # Negative input, but the export-time helper result is baked in, so y * x is unchanged.
    self.assertTrue(
        torch._dynamo.utils.same(graph(torch.tensor([-0.5])), real_result)
    )
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_none_control_flow_free_func(self):
    """Same as the ``None`` control-flow test, but the helper is a free function."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return None if x.item() < 0 else x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            maybe_x = helper_fn(x)
            if maybe_x is None:
                return y
            return y * maybe_x

    mod = MyModule()
    eager_out = mod(torch.tensor([-1]))
    # Negative input -> .item() < 0 -> helper returns None -> forward returns y.
    self.assertEqual(eager_out, torch.tensor([0.5]))
    exported_graph, _guards = torch._dynamo.export(mod, torch.tensor([-1]))
    # Positive input now, but the None result was captured at export time,
    # so the graph still returns y.
    traced_out = exported_graph(torch.tensor([2]))
    self.assertTrue(torch._dynamo.utils.same(traced_out, eager_out))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_not_none_control_flow_pos(self):
    """A different *positive* call-time input still sees the export-time constant.

    ``helper_fn`` is marked ``assume_constant_result``, so the tensor it
    returned during export (``[2]``) is baked into the graph. Calling the graph
    with another positive value (``[5]``) takes the same branch. If the helper
    were re-evaluated the result would be ``0.5 * 5 = 2.5``. Because its output
    was frozen, the result stays ``0.5 * 2 = 1.0``. The negative call-time
    input case is covered by ``test_export_with_constant_not_none_control_flow``.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            if x.item() < 0:
                return None
            else:
                return x

        def forward(self, x):
            y = torch.tensor([0.5])
            x = self.helper_fn(x)
            if x is None:
                return y
            return y * x

    module = MyModule()
    real_result = module(torch.tensor([2]))
    # X is positive, so .item() > 0, which means we return y * x
    self.assertEqual(real_result, torch.tensor([1.0]))
    graph, guards = torch._dynamo.export(module, torch.tensor([2]))
    result = graph(torch.tensor([5]))
    # Call-time input is also positive but differs from the traced one: the
    # helper's export-time output ([2]) is a constant, so we get 0.5 * 2, not 0.5 * 5.
    self.assertTrue(torch._dynamo.utils.same(result, real_result))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_not_none_control_flow_free_func(self):
    """Same as the non-``None`` control-flow test, but the helper is a free function."""

    @torch._dynamo.assume_constant_result
    def helper_fn(x):
        return None if x.item() < 0 else x

    class MyModule(torch.nn.Module):
        def forward(self, x):
            y = torch.tensor([0.5])
            maybe_x = helper_fn(x)
            if maybe_x is None:
                return y
            return y * maybe_x

    mod = MyModule()
    eager_out = mod(torch.tensor([2]))
    # Positive input -> helper returns x -> forward returns 0.5 * 2.
    self.assertEqual(eager_out, torch.tensor([1.0]))
    exported_graph, _guards = torch._dynamo.export(mod, torch.tensor([2]))
    # Negative input now, but helper_fn was compiled to return the
    # export-time x, so the graph still returns y * x.
    traced_out = exported_graph(torch.tensor([-0.5]))
    self.assertTrue(torch._dynamo.utils.same(traced_out, eager_out))
@patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
def test_export_with_constant_not_return_const(self):
    """A non-tensor constant (a string attribute) is frozen at export time.

    ``helper_fn`` returns ``self.val``. After export, changing ``val`` must not
    change what the exported graph returns.
    """

    class MyModule(torch.nn.Module):
        @torch._dynamo.assume_constant_result
        def helper_fn(self, x):
            return self.val

        def forward(self, x):
            y = torch.tensor([0.5])
            tag = self.helper_fn(x)
            if tag == "A":
                return y
            return -1

    mod = MyModule()
    mod.val = "A"
    res_a = mod(torch.tensor([2]))
    exported_graph, _guards = torch._dynamo.export(mod, torch.tensor([2]))
    # Mutate the attribute after export; the graph keeps the "A" branch.
    mod.val = "B"
    res_b = exported_graph(torch.tensor([2]))
    self.assertTrue(torch._dynamo.utils.same(res_a, res_b))
def test_export_decomp(self):
    """``decomposition_table`` is honored when exporting an aten-level graph.

    With ``aten.t.default`` mapped to ``nop`` (which returns ``x.cos()``), no
    transpose node survives in the exported graph. With no table, both
    ``x.t()`` calls appear as ``aten.t.default`` nodes.
    """

    def f(x):
        return x.t() + x.t()

    def nop(x):
        return x.cos()

    def count_t_nodes(gm):
        # Number of aten.t.default nodes remaining in the exported graph.
        return sum(1 for n in gm.graph.nodes if n.target == torch.ops.aten.t.default)

    # Example inputs are passed positionally, like the other export calls in
    # this file (no tuple wrapping).
    graph, _ = torch._dynamo.export(
        f,
        torch.randn(5),
        aten_graph=True,
        decomposition_table={torch.ops.aten.t.default: nop},
    )
    self.assertEqual(count_t_nodes(graph), 0)

    graph, _ = torch._dynamo.export(
        f, torch.randn(5), aten_graph=True, decomposition_table=None
    )
    self.assertEqual(count_t_nodes(graph), 2)
def test_export_decomp_asserts_bad_args(self):
    """A ``decomposition_table`` without ``aten_graph=True`` is rejected.

    The test expects ``export`` to raise ``AssertionError`` in this case.
    """

    def f(x):
        return x.t() + x.t()

    def nop(x):
        return x.cos()

    with self.assertRaises(AssertionError):
        # Only the raise matters; the return value is never used.
        torch._dynamo.export(
            f,
            torch.randn(5),
            aten_graph=False,
            decomposition_table={torch.ops.aten.t.default: nop},
        )
def test_export_decomp_asserts_bad_args_mode(self):
    """``tracing_mode="symbolic"`` without ``aten_graph=True`` is rejected.

    The test expects ``export`` to raise ``AssertionError`` in this case.
    """

    def f(x):
        return x.t() + x.t()

    with self.assertRaises(AssertionError):
        # Only the raise matters; the return value is never used.
        torch._dynamo.export(
            f, torch.randn(5), aten_graph=False, tracing_mode="symbolic"
        )
if __name__ == "__main__":
    # Run this file's tests through dynamo's shared test entry point.
    torch._dynamo.test_case.run_tests()