# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_export_persist_assert)
"""

import copy
import functools
import inspect
import io
import operator
import unittest
from collections.abc import Sequence
from enum import Enum
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.experimental.control_flow import cond
from torch._dynamo import config
from torch._dynamo.exc import UserError
from torch._dynamo.testing import normalize_gm
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses import fake_tensor
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    DimDynamic,
    ShapeEnv,
    StatelessSymbolicContext,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_device_type import instantiate_device_type_tests


@torch._dynamo.assume_constant_result
def dynamo_assume_constant_result_global_function():
    return "test"


class ExportTests(torch._dynamo.test_case.TestCase):
    # TODO(voz): Refactor to a shared test function.
    # The tests in this file are a little redundant:
    # they all take a func, run it eagerly, then export it, then compare the results.
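
    # A shared helper along those lines might look like the sketch below
    # (hypothetical; the tests in this file do not call it). It compiles the
    # function with the eager backend, resets dynamo, exports the same
    # function, and checks that both paths agree.
    def _check_export_matches_eager(self, func, *inps):
        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)  # eager-compiled reference result
        torch._dynamo.reset()
        out_graph = torch._dynamo.export(func)(*inps)[0]  # exported fx GraphModule
        dynamo_result = out_graph(*inps)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))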

    def test_export(self):
        def pre_attention_state_ops(input, mems, state):
            lc_key = state[0]
            lc_val = state[1]
            bar = []
            for _ in range(0, 4):
                bar2 = []
                for _ in range(0, 3):
                    bar2.append(
                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                    )
                bar.append(bar2)

            return bar

        def func():
            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
            state = [
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            ]
            i = torch.tensor(
                [
                    [0.0313, -0.1487, -0.3846, -0.5321],
                    [-1.7073, 1.3331, -0.0890, -1.4935],
                    [-0.8314, -0.1862, -0.5935, 1.5232],
                ]
            )
            return pre_attention_state_ops(i, mems, state)

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func()

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)()
        out_graph = exported[0]

        dynamo_result = out_graph()
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_no_tensor_computation_fail(self):
        with self.assertRaisesRegex(
            AssertionError,
            "Failed to produce a graph",
        ):
            inp = [torch.randn(3)]
            inp2 = 2
            inps = [inp, inp2]

            def func(x, y):
                return x

            torch._dynamo.export(func, same_signature=False)(*inps)

    def test_no_tensor_computation(self):
        inp = [torch.randn(3)]
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return x

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    x = arg0
    return pytree.tree_unflatten([x], self._out_spec)""",
        )

    def test_no_tensor_computation_2(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    x = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_shape_control_flow_1(self):
        def func(x):
            if x.shape[0] > 10:
                return x.cos()
            return x.sin()

        opt_func = torch.compile(func, backend="eager")
        real_result = opt_func(torch.ones(6, 4))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.ones(6, 4))
        out_graph, out_guards = exported

        dynamo_result = out_graph(torch.ones(6, 4))

        from torch._guards import GuardSource

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        hit = False
        for guard in out_guards:
            if guard.source == GuardSource.SHAPE_ENV:
                hit = True
                self.assertExpectedInline(
                    guard.code_list,
                    """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] and L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""",  # noqa: B950
                )
                break

        self.assertTrue(hit)

    def test_export_control_flow_with_getattr(self):
        class Animal(Enum):
            COW = "moo"

        class MyModule(torch.nn.Module):
            def __init__(self, a):
                super().__init__()
                self.a = a

            def forward(self, x):
                if self.a == Animal.COW.value:
                    return x * x
                else:
                    raise ValueError("bad")

        module = MyModule("moo")
        input = (torch.ones(4, 3),)
        resA = module(*input)
        graph, _ = torch._dynamo.export(module)(*input)
        resB = graph(*input)
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    def test_export_graph_bypass(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_shallow_list_copy_wo_side_effects(self):
        def f(x):
            y = x.copy()
            return y[0] + y[1]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        self.assertTrue(torch._dynamo.utils.same(gm(inp), f(inp)))

    def test_export_with_shallow_list_copy_with_side_effects(self):
        def f(x):
            y = x.copy()
            x[0] = x[1]
            y.append(torch.tensor([[100]]))
            return x[0] + x[1], y[0] + y[1], y[2]

        inp = [torch.tensor([1.3, 3.77, 0.1]), torch.tensor([8.7, 6.23, 9.9])]
        gm = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            inp
        ).graph_module
        res = gm(inp)
        ref = f(inp)
        self.assertTrue(torch._dynamo.utils.same(res, ref))
        self.assertEqual(res[0], res[1])

    def test_export_mismatched_out_2(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(torch.tensor([[[1.3737, 0.1]]]))
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a[0].item() + b[0].item() + c[0].item()

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return b[0].item() + c[0].item() + a[0].item() + a[0].item()

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return a, b[0].item() + c[0].item() + a[0].item() + a[0].item(), a

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_aten_graph(self):
        def pre_attention_state_ops(input, mems, state):
            lc_key = state[0]
            lc_val = state[1]
            bar = []
            for _ in range(0, 4):
                bar2 = []
                for _ in range(0, 3):
                    bar2.append(
                        lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
                    )
                bar.append(bar2)

            return bar

        def func():
            mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
            state = [
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
                torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
            ]
            i = torch.tensor(
                [
                    [0.0313, -0.1487, -0.3846, -0.5321],
                    [-1.7073, 1.3331, -0.0890, -1.4935],
                    [-0.8314, -0.1862, -0.5935, 1.5232],
                ]
            )
            return pre_attention_state_ops(i, mems, state)

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func()

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)()
        out_graph = exported[0]

        dynamo_result = out_graph()
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_no_tensor_computation_with_aten_graph(self):
        inp = [torch.randn(3)]
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return x

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([arg0_1], self._out_spec)""",
        )

    def test_no_tensor_computation_2_with_aten_graph(self):
        inp = torch.randn(3)
        inp2 = 2
        inps = [inp, inp2]

        def func(x, y):
            return y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, x, y):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    arg0_1 = arg0
    return pytree.tree_unflatten([2], self._out_spec)""",
        )

    def test_export_mismatched_out_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_bypass_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_unpack_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return x[0], first * second, x[1], x[2]

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_mismatched_out_2_with_aten_graph(self):
        def func(x):
            y = x + 1
            return ([x, x], (y, y))

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(torch.tensor([[[1.3737, 0.1]]]))

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(
            torch.tensor([[[1.3737, 0.1]]])
        )
        out_graph = exported[0]

        dynamo_result = out_graph(torch.tensor([[[1.3737, 0.1]]]))

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_list_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[2]
            second = x[2]
            return first * second, x

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_graph_with_complex_reorder_with_aten_graph(self):
        inp = [
            torch.tensor([0.1, 0.1]),
            torch.tensor([0.2, 0.2]),
            torch.tensor([0.3, 0.3]),
            torch.tensor([0.4, 0.4]),
        ]

        def func(x):
            first = x[0]
            second = x[1]
            third = x[2]
            return third, first, second, first * second, first * third

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_2_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            y = x + 1
            return y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(inp)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]

        dynamo_result = out_graph(inp)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.4, 0.4])
        inps = [inp, inp2]

        def func(x, z):
            y = x + 1
            return y, y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y, y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dupes_and_bypass_reorder_with_non_tensor_arg_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return z, y, y

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_dupes_and_bypass_with_non_tensor_output_with_aten_graph(self):
        inp = torch.tensor([0.1, 0.1])
        inp2 = torch.tensor([0.1, 0.1])
        inp3 = 4
        inps = [inp, inp2, inp3]

        def func(x, z, k):
            y = x + k
            return y[0].item(), y, z

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_zeroes_in_and_out_different_shape_on_test_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            return [[a], [b, c], [a + b], [[c + c]]]

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_func_return_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c

            def func2(y):
                return x * y

            return func2(x)

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_dict_return_with_aten_graph(self):
        inp = torch.zeros(10)
        inp2 = torch.zeros(10)
        inp3 = torch.zeros(10)
        inps = [inp, inp2, inp3]

        inps_rand = [torch.randn(10), torch.randn(10), torch.randn(10)]

        def func(a, b, c):
            x = a + b + c
            return {"a": x}

        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps_rand)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps_rand)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_stack_trace(self):
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        exported = torch._dynamo.export(MyModule(), aten_graph=False)(inp)
        out_graph = exported[0]

        for node in out_graph.graph.nodes:
            if node.op not in {"placeholder", "output"}:
                self.assertTrue(node.stack_trace is not None)
                self.assertTrue(node.meta["nn_module_stack"] is not None)
                self.assertTrue(node.meta["source_fn_stack"] is not None)

        torch._dynamo.reset()

        exported = torch._dynamo.export(MyModule(), aten_graph=True)(inp)
        out_graph = exported[0]
        for node in out_graph.graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.stack_trace is not None)
                self.assertTrue(node.meta["nn_module_stack"] is not None)
                self.assertTrue(node.meta["source_fn_stack"] is not None)
                self.assertTrue(node.meta["val"] is not None)
                self.assertTrue(node.meta["original_aten"] is not None)

    def test_export_preserves_nn_module_stack_for_get_attr(self):
        inp = torch.randn(4, 4)

        class MyBlock(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.ones(1, 1))
                self.buffer = torch.nn.Buffer(torch.ones(1, 1))

            def forward(self, x):
                x = torch.nn.functional.linear(x, torch.randn(4, 4))
                return torch.cos(x).relu() + self.weight + self.buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.block = MyBlock()

            def forward(self, x):
                out = self.block(x)
                return out

        m = MyModule()
        exported = torch._dynamo.export(m, aten_graph=False)(inp)
        out_graph = exported[0]

        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertEqual(attr_access_count, 2)

        torch._dynamo.reset()

        exported = torch._dynamo.export(m, aten_graph=True)(inp)
        out_graph = exported[0]

        attr_access_count = 0
        for node in out_graph.graph.nodes:
            if node.op == "get_attr":
                attr_access_count += 1
                self.assertTrue(node.meta["nn_module_stack"] is not None)
        self.assertEqual(attr_access_count, 2)

    def test_export_compare_optimize_with_make_fx(self):
        inp = torch.tensor([0.1, 0.1])
        linear = torch.nn.Linear(2, 2)

        def func(x):
            x = x + 1
            y = x.t()
            y = y.relu()
            y = linear(y)
            return y

        exported = torch._dynamo.export(func, aten_graph=True)(inp)
        out_graph = exported[0]
        export_result = out_graph(inp)

        torch._dynamo.reset()

        def compiler(gm, sample_inputs):
            def fw(*args):
                aten_gm = make_fx(gm)(*args)
                return aten_gm(*args)

            return fw

        opt_func = torch.compile(func, backend=compiler, fullgraph=True, dynamic=True)
        make_fx_result_through_backend = opt_func(inp)

        fx_g = make_fx(func)(inp)
        make_fx_result_through_direct = fx_g(inp)

        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_backend, export_result)
        )
        self.assertTrue(
            torch._dynamo.utils.same(make_fx_result_through_direct, export_result)
        )

    def test_export_with_constant_method_on_module(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_method_on_module_invoke_twice(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = self.helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return torch.nonzero(x)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + self.helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_global_function(self):
        class MyModule(torch.nn.Module):
            def forward(self):
                a = dynamo_assume_constant_result_global_function()
                b = dynamo_assume_constant_result_global_function()
                return a + b

        module = MyModule()
        graph, _ = torch._dynamo.export(module)()
        result = graph()
        self.assertEqual(result, "testtest")

    def test_export_with_constant_free_function_and_class_method(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x)
                return y

        module = MyModule()
        real_result = module(torch.tensor([[1.0, 0], [0, 0]]))
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        result = graph(torch.tensor([[1.0, 0.0], [0, 0]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(torch.tensor([[1, 0], [0.25, 0.25]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function_and_class_method_multiarg(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 2))
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x, z):
                y = torch.sin(x)
                x = self.linear(x)
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[1.0, 0.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]), torch.tensor([[1, 0], [0.25, 0.25]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_free_function_and_class_method_multiarg_diff(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return torch.nonzero(x)

        class MyModule(torch.nn.Module):
            def forward(self, x, z):
                y = helper_fn(x) + helper_fn(z)
                return y

        module = MyModule()
        real_result = module(
            torch.tensor([[1.0, 0], [0, 0]]), torch.tensor([[1.0, 0], [0, 0]])
        )
        module = MyModule()
        graph, _ = torch._dynamo.export(module)(
            torch.tensor([[0.0, 0], [0, 0]]), torch.tensor([[0.0, 0], [0.5, 0]])
        )
        result = graph(
            torch.tensor([[1.0, 0.0], [0, 0]]), torch.tensor([[0.0, 1.0], [0, 0]])
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))
        result = graph(
            torch.tensor([[1, 0], [0.25, 0.25]]),
            torch.tensor([[0.33, 0.33], [0.25, 0.25]]),
        )
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_tuple_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return (torch.nonzero(x), torch.nonzero(x))

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_list_nonzero(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return [torch.nonzero(x), torch.nonzero(x)]

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_list_nonzero_free_function(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            return [torch.nonzero(x), torch.nonzero(x)]

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                elements = helper_fn(x)
                all_y = []
                for element in elements:
                    for item in element:
                        all_y.append(y * item)
                return all_y

        module = MyModule()
        real_result = module(torch.tensor([1.0, 1.0]))
        graph, _ = torch._dynamo.export(module)(torch.tensor([1.0, 1.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_dict_values(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return {"x": x, "x^2": x * x}

            def forward(self, x):
                y = torch.tensor([0.5])
                elements = self.helper_fn(x)
                y = y * elements["x"]
                y = y * elements["x^2"]
                return y

        module = MyModule()
        real_result = module(torch.tensor([2.0, 2.0]))
        graph, _ = torch._dynamo.export(module)(torch.tensor([2.0, 2.0]))

        # Tensor input can be almost anything here, and the result will capture what we
        # made constant at compile time.
        result = graph(torch.tensor([[[1.0, 0], [0, 0]], [[1.0, 0], [0, 0]]]))
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, _ = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, _ = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([-1]))

        # X is negative, so .item() < 0, which means we return y
        self.assertEqual(real_result, torch.tensor([0.5]))

        graph, _ = torch._dynamo.export(module)(torch.tensor([-1]))
        result = graph(torch.tensor([2]))
        # X is positive, but we compiled helper_fn to return None, so it will still return y
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow_pos(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                if x.item() < 0:
                    return None
                else:
                    return x

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, _ = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_none_control_flow_free_func(self):
        @torch._dynamo.assume_constant_result
        def helper_fn(x):
            if x.item() < 0:
                return None
            else:
                return x

        class MyModule(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor([0.5])
                x = helper_fn(x)
                if x is None:
                    return y
                return y * x

        module = MyModule()
        real_result = module(torch.tensor([2]))

        # X is positive, so .item() > 0, which means we return y * x
        self.assertEqual(real_result, torch.tensor([1.0]))

        graph, _ = torch._dynamo.export(module)(torch.tensor([2]))
        result = graph(torch.tensor([-0.5]))
        # X is negative, but we compiled helper_fn to return x, so it will still return y * x
        self.assertTrue(torch._dynamo.utils.same(result, real_result))

    def test_export_with_constant_not_return_const(self):
        class MyModule(torch.nn.Module):
            @torch._dynamo.assume_constant_result
            def helper_fn(self, x):
                return self.val

            def forward(self, x):
                y = torch.tensor([0.5])
                x = self.helper_fn(x)
                if x == "A":
                    return y
                return -1

        module = MyModule()
        module.val = "A"
        resA = module(torch.tensor([2]))
        graph, _ = torch._dynamo.export(module)(torch.tensor([2]))
        module.val = "B"
        resB = graph(torch.tensor([2]))
        self.assertTrue(torch._dynamo.utils.same(resA, resB))

    def test_export_with_builtin_op_on_assume_constant(self):
        @torch._dynamo.assume_constant_result
        def get_y(y) -> torch.Tensor:
            return y

        class Bob(torch.nn.Module):
            def __init__(self, p, val) -> None:
                super().__init__()
                self.p = p
                self.y = torch.nn.Parameter(torch.tensor(val))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # This only looks dynamic but it's actually a constant value
                if get_y(self.y) < self.p:
                    return torch.cat([x, x])
                else:
                    return x

        model = Bob(0.5, 0.3)
        inp = torch.ones(3, 4)
        graph, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))

    def test_export_with_constant_in_unspecialized_nn_module(self):
        class Module(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.y = y

            @torch._dynamo.assume_constant_result
            def check(self):
                return self.y[0].item() == 1

            def forward(self, x):
                # This line leads to module obj being tracked as UnspecializedNNModuleVariable in dynamo
                self.device = x.device

                if self.check():
                    return x + 1
                else:
                    return x + 2

        model = Module(torch.tensor([1]))
        inp = torch.ones(3, 4)
        graph, _ = torch._dynamo.export(model)(inp)
        self.assertEqual(model(inp), graph(inp))

    def test_export_decomp(self):
        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        graph, _ = torch._dynamo.export(
            f,
            aten_graph=True,
            decomposition_table={torch.ops.aten.t.default: nop},
        )(torch.randn(5))
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            0,
        )

        graph, _ = torch._dynamo.export(f, aten_graph=True, decomposition_table=None)(
            torch.randn(5)
        )
        self.assertEqual(
            len([n for n in graph.graph.nodes if n.target == torch.ops.aten.t.default]),
            2,
        )

    def test_export_decomp_asserts_bad_args(self):
        def f(x):
            return x.t() + x.t()

        def nop(x):
            return x.cos()

        with self.assertRaises(AssertionError):
            torch._dynamo.export(
                f,
                (torch.randn(5)),
                aten_graph=False,
                decomposition_table={torch.ops.aten.t.default: nop},
            )

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_module_layer(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                def true_fn(val):
                    return self.linear(val) * torch.tensor(2)

                def false_fn(val):
                    return self.linear(val) * torch.tensor(-1)

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)

        torch._dynamo.reset()

        exported = torch._dynamo.export(mod.forward)(pred, x)
        out_graph = exported[0]

        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

        # New X, just to show we did not specialize
        x = x * -1
        pred = torch.tensor(x[0][0].item() < 0)
        real_result_2 = mod.forward(pred, x)
        dynamo_result_2 = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result_2, dynamo_result_2))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_branches_calling_methods(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            # ok
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return cond(pred, self.true_fn, self.false_fn, [x])

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        real_result = mod.forward(pred, x)
        out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    @config.patch(capture_scalar_outputs=True)
    def test_export_with_cond_closure(self):
        from functorch.experimental.control_flow import cond

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x])

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                def true_fn(x):
                    return x * 2

                def false_fn(x):
                    return x - 2

                return cond(pred, true_fn, false_fn, [x + 1])

        class FooBar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, pred, x):
                y = x + x

                def true_fn(x, y):
                    return self.linear(x) * (x + y)

                def false_fn(x, y):
                    return x * (y - x)

                return cond(pred, true_fn, false_fn, [x, y])

        for Module in [Foo, Bar, FooBar]:
            mod = Module()
            x = torch.randn([3, 3], requires_grad=True)
            pred = torch.tensor(x[0][0].item() < 0)
            real_result = mod.forward(pred, x)
            out_graph, _ = torch._dynamo.export(mod.forward)(pred, x)
            dynamo_result = out_graph(pred, x)
            self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_cond_with_closed_function(self):
        def hello(x):
            return x + 1

        def hi(x):
            return x + 2

        def foo(pred, x):
            def true_fn(x):
                return hello(x)

            def false_fn(x):
                return hi(x)

            return cond(pred, true_fn, false_fn, [x])

        x = torch.randn(5)
        pred = x[0] > 0
        real_result = foo(pred, x)
        out_graph, _ = torch._dynamo.export(foo)(pred, x)
        dynamo_result = out_graph(pred, x)
        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_with_cond_dynamic_shape_pred(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                return cond(x.shape[0] <= 2, true_fn, false_fn, [x])

        class Module2(torch.nn.Module):
            def forward(self, x):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x[:2]

                return cond(x.shape[0] <= 2, true_fn, false_fn, (x,))

        mods = [Module(), Module2()]
        for mod in mods:
            x = torch.randn(2, 2)
            out_graph, _ = torch._dynamo.export(mod)(x)
            self.assertExpectedInline(
                out_graph.code.strip(),
                """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sym_size_int = torch.ops.aten.sym_size.int(l_x_, 0)
    le = sym_size_int <= 2; sym_size_int = None
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, (l_x_,)); le = cond_true_0 = cond_false_0 = l_x_ = None
    getitem_3 = cond[0]
    sym_size_int_1 = torch.ops.aten.sym_size.int(getitem_3, 0); getitem_3 = None
    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(sym_size_int_1); sym_constrain_range_for_size_default = None
    ge = sym_size_int_1 >= 2; sym_size_int_1 = None
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 2 on node 'ge'"); ge = _assert_scalar_default = None
    getitem_2 = cond[0]; cond = None
    return pytree.tree_unflatten([getitem_2], self._out_spec)""",  # noqa: B950
            )
            self.assertExpectedInline(
                out_graph.cond_true_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    add = l_x__1 + l_x__1; l_x__1 = None
    return (add,)""",
            )
            self.assertExpectedInline(
                out_graph.cond_false_0.code.strip(),
                """\
def forward(self, l_x_):
    l_x__1 = l_x_
    getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
    return (getitem,)""",
            )
            # We could successfully export branches that return different sizes
            torch._dynamo.export(mod)(torch.randn(3, 2))

            # We specialize into one of the branches since predicate is a python boolean.
            test_x = torch.randn(3, 2)
            mod(test_x)

    def test_export_with_map_cond(self):
        from functorch.experimental.control_flow import cond, map

        class Module(torch.nn.Module):
            def inner(self, x, pred):
                def true_fn(x):
                    return x + x

                def false_fn(x):
                    return x * x

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred)

                return map(body, xs, pred)

        mod = Module()
        x = torch.randn(3, 2, 1)
        pred_x = torch.tensor(True)

        y = torch.randn(4, 3, 2)
        pred_y = torch.tensor(False)
        real_result = mod(pred_y, y)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        self.assertEqual(real_result, out_graph(pred_y, y))

    def test_export_with_map_zero_sized_tensor(self):
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return map(body, xs)

        mod = Module()
        xs = torch.randn(0, 2)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "zero-sized tensor",
        ):
            torch._dynamo.export(mod)(xs)
def test_export_meta_val(self):
|
|
def f(x, y, z):
|
|
return x * y + z
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
)(
|
|
torch.ones(3, 2),
|
|
torch.zeros(3, 2),
|
|
torch.ones(3, 2),
|
|
)
|
|
for node in gm.graph.nodes:
|
|
if node.op == "placeholder":
|
|
self.assertIn("val", node.meta)
|
|
|
|
def test_input_container_type(self):
|
|
def f(x: torch.Tensor, y: list[torch.Tensor]) -> dict[str, torch.Tensor]:
|
|
return {"a": x.sum() + sum(y).sum()}
|
|
|
|
inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])
|
|
|
|
gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)
|
|
|
|
self.assertEqual(gm(*inp), f(*inp))
|
|
|
|
@config.patch(assume_static_by_default=False)
|
|
def test_export_symbolic_shape(self):
|
|
def f(x: torch.Tensor) -> torch.Tensor:
|
|
return torch.empty(x.shape[0] * 2)
|
|
|
|
inp = (torch.randn(6, 5),)
|
|
gm, _ = torch._dynamo.export(f, aten_graph=True)(*inp)
|
|
|
|
has_sym_size = False
|
|
for node in gm.graph.nodes:
|
|
if node.target is torch.ops.aten.sym_size.int:
|
|
has_sym_size = True
|
|
|
|
self.assertTrue(has_sym_size)
|
|
|
|
@config.patch(assume_static_by_default=False)
|
|
def test_dynamic_slicing(self):
|
|
def f(x):
|
|
return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]
|
|
|
|
gm_aten_mode, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
|
|
|
|
inp = torch.randn(6, 7)
|
|
self.assertEqual(gm_aten_mode(inp).shape, f(inp).shape)
|
|
|
|
count = 0
|
|
# aten graph should flatten getitem calls to actual
|
|
# slice kernel call.
|
|
for node in gm_aten_mode.graph.nodes:
|
|
if (
|
|
node.op == "call_function"
|
|
and node.target == torch.ops.aten.slice.Tensor
|
|
):
|
|
count += 1
|
|
|
|
self.assertEqual(count, 2)
|
|
|
|
gm_torch_mode, _ = torch._dynamo.export(f, aten_graph=False)(torch.randn(4, 5))
|
|
|
|
# In torch mode, the graph should contain 3 getitem methods
|
|
# one for x.shape[0]-2 and one for x.shape[1]-1 and one for slice
|
|
# this is because Tensor class has its' own getitem method
|
|
# which gets translated to aten.Slice later.
|
|
count = 0
|
|
for node in gm_torch_mode.graph.nodes:
|
|
if node.op == "call_function" and node.target == operator.getitem:
|
|
count += 1
|
|
|
|
self.assertEqual(count, 1)
|
|
self.assertEqual(gm_torch_mode(inp).shape, f(inp).shape)
|
|
|
|
def test_dynamic_slicing_invalid(self):
|
|
def g(x, y):
|
|
return x[y : x.shape[0]]
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.Unsupported,
|
|
"Dynamic slicing with Tensor arguments",
|
|
):
|
|
torch._dynamo.export(
|
|
g,
|
|
aten_graph=True,
|
|
)(
|
|
torch.randn(4, 5),
|
|
torch.tensor(2),
|
|
)
|
|
|
|
@config.patch(capture_scalar_outputs=True)
|
|
def test_dynamic_slicing_simple(self):
|
|
def f(x):
|
|
return x[slice(None, None, None)]
|
|
|
|
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
|
|
|
|
inp = torch.randn(6, 7)
|
|
self.assertEqual(gm(inp), f(inp))
|
|
|
|
def test_pre_dispatch_simple(self):
|
|
def f(x):
|
|
y = torch.ones_like(x)
|
|
return torch.matmul(x, y)
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
pre_dispatch=True,
|
|
tracing_mode="fake",
|
|
)(
|
|
torch.randn(5, 5),
|
|
)
|
|
|
|
inp = torch.randn(6, 6)
|
|
self.assertEqual(gm(inp), f(inp))
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x):
|
|
arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
|
arg0_1 = arg0
|
|
ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
|
|
matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None
|
|
return pytree.tree_unflatten([matmul], self._out_spec)""",
|
|
)
|
|
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
|
def test_export_cond_in_aten_symbolic(self):
|
|
class ConditionOp(torch.nn.Module):
|
|
def true_fn(self, x, y):
|
|
return x * y
|
|
|
|
def false_fn(self, x, y):
|
|
return x + y
|
|
|
|
def forward(self, pred, x, y):
|
|
return cond(pred, self.true_fn, self.false_fn, [x, y])
|
|
|
|
model = ConditionOp()
|
|
inp = (
|
|
torch.tensor(False),
|
|
torch.randn(4, 4),
|
|
torch.randn(4, 4),
|
|
)
|
|
gm, _ = torch._dynamo.export(model, aten_graph=True)(*inp)
|
|
|
|
gm.print_readable()
|
|
|
|
self.assertEqual(gm(*inp), model(*inp))
|
|
|
|
def test_export_with_kwargs(self):
|
|
def fn_with_kwargs(pos0, tuple0, *myargs, mykw0=None, **mykwargs):
|
|
out = pos0
|
|
for arg in tuple0:
|
|
out *= arg
|
|
for arg in myargs:
|
|
out *= arg
|
|
out *= mykw0
|
|
out *= mykwargs["input0"] * mykwargs["input1"]
|
|
return out
|
|
|
|
mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
|
|
tuple0 = (torch.randn(4), torch.randn(4))
|
|
mykw0 = torch.randn(4)
|
|
pos0 = torch.randn(4)
|
|
myargs = [torch.randn(4), torch.randn(4)]
|
|
|
|
expected_argument_names = [
|
|
"pos0",
|
|
"tuple0",
|
|
"myargs_0",
|
|
"myargs_1",
|
|
"mykw0",
|
|
"input0",
|
|
"input1",
|
|
]
|
|
self._test_export_preserving_original_signature(
|
|
fn_with_kwargs,
|
|
expected_argument_names,
|
|
pos0,
|
|
tuple0,
|
|
*myargs,
|
|
mykw0=mykw0,
|
|
**mykwargs,
|
|
)
|
|
|
|
def test_export_with_kwargs_and_empty_args(self):
|
|
def fn_with_kwargs(mykw0=None, **mykwargs):
|
|
out = mykw0
|
|
out *= mykwargs["input0"] * mykwargs["input1"]
|
|
return out
|
|
|
|
mykwargs = {"input0": torch.randn(4), "input1": torch.randn(4)}
|
|
mykw0 = torch.randn(4)
|
|
|
|
expected_argument_names = ["mykw0"] + list(mykwargs.keys())
|
|
self._test_export_preserving_original_signature(
|
|
fn_with_kwargs, expected_argument_names, mykw0, **mykwargs
|
|
)
|
|
|
|
def test_export_with_args_and_empty_kwargs(self):
|
|
def fn_with_kwargs(pos0, tuple0, *myargs):
|
|
out = pos0
|
|
for arg in tuple0:
|
|
out *= arg
|
|
for arg in myargs:
|
|
out *= arg
|
|
return out
|
|
|
|
tuple0 = (torch.randn(4), torch.randn(4))
|
|
pos0 = torch.randn(4)
|
|
myargs = [torch.randn(4), torch.randn(4)]
|
|
|
|
expected_argument_names = ["pos0", "tuple0", "myargs_0", "myargs_1"]
|
|
self._test_export_preserving_original_signature(
|
|
fn_with_kwargs, expected_argument_names, pos0, tuple0, *myargs
|
|
)
|
|
|
|
@common_utils.parametrize(
|
|
"default_value",
|
|
[
|
|
common_utils.subtest(None, name="None"),
|
|
common_utils.subtest(42.0, name="float"),
|
|
common_utils.subtest(
|
|
# FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
|
|
torch.randn(4),
|
|
name="tensor",
|
|
decorators=[unittest.expectedFailure],
|
|
),
|
|
common_utils.subtest(
|
|
# FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
|
|
(torch.randn(4),),
|
|
name="tuple",
|
|
decorators=[unittest.expectedFailure],
|
|
),
|
|
],
|
|
)
|
|
def test_export_with_args_with_default(self, default_value):
|
|
def fn(pos0, pos1_default=default_value):
|
|
out = pos0
|
|
if pos1_default is None:
|
|
pos1_default = torch.randn(4)
|
|
if isinstance(pos1_default, tuple):
|
|
pos1_default = pos1_default[0]
|
|
out *= pos1_default
|
|
return out
|
|
|
|
pos0 = torch.randn(4)
|
|
expected_argument_names = ["pos0"]
|
|
self._test_export_preserving_original_signature(
|
|
fn, expected_argument_names, pos0
|
|
)
|
|
|
|
@common_utils.parametrize(
|
|
"default_value",
|
|
[
|
|
common_utils.subtest(None, name="None"),
|
|
common_utils.subtest(42.0, name="float"),
|
|
common_utils.subtest(
|
|
# FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
|
|
torch.randn(4),
|
|
name="tensor",
|
|
decorators=[unittest.expectedFailure],
|
|
),
|
|
common_utils.subtest(
|
|
# FIXME: AssertionError: Dynamo input and output is a strict subset of traced input/output
|
|
(torch.randn(4),),
|
|
name="tuple",
|
|
decorators=[unittest.expectedFailure],
|
|
),
|
|
],
|
|
)
|
|
def test_export_with_kwargs_with_default(self, default_value):
|
|
def fn(pos0, *, kw0, kw1_default=default_value, **kwargs):
|
|
out = pos0
|
|
out += kw0
|
|
if kw1_default is None:
|
|
kw1_default = torch.randn(4)
|
|
elif isinstance(kw1_default, tuple):
|
|
kw1_default = kw1_default[0]
|
|
out += kw1_default
|
|
out += kwargs["kw2"]
|
|
return out
|
|
|
|
pos0 = torch.randn(4)
|
|
kw0 = torch.randn(4)
|
|
kw2 = torch.randn(4)
|
|
|
|
args = (pos0,)
|
|
kwargs = {"kw0": kw0, "kw2": kw2}
|
|
expected_argument_names = ["pos0", "kw0", "kw2"]
|
|
self._test_export_preserving_original_signature(
|
|
fn, expected_argument_names, *args, **kwargs
|
|
)
|
|
|
|
def test_export_with_wrapped_fn(self):
|
|
# To ensure dynamo.export is robust to wrapped functions
|
|
# when it cannot use `inspect` to retrieve original signature
|
|
# info.
|
|
def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
|
|
out = pos0
|
|
out += pos1
|
|
out += kw0
|
|
out += kw1
|
|
for arg in args:
|
|
out += arg
|
|
for kwarg in kwargs.values():
|
|
out += kwarg
|
|
return out
|
|
|
|
def wrapped_fn(*args, **kwargs):
|
|
return _fn(*args, **kwargs)
|
|
|
|
pos0 = torch.randn(4)
|
|
kw0 = torch.randn(4)
|
|
args = (pos0, torch.randn(4), torch.randn(4))
|
|
kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
|
|
expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
|
|
kwargs.keys()
|
|
)
|
|
|
|
self._test_export_preserving_original_signature(
|
|
wrapped_fn, expected_argument_names, *args, **kwargs
|
|
)
|
|
|
|
def test_export_with_functools_wrapped_method(self):
|
|
def test_decorator(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
return func(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
@test_decorator
|
|
def method_to_test(self, pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
|
|
out = pos0
|
|
out += pos1
|
|
out += kw0
|
|
out += kw1
|
|
for arg in args:
|
|
out += arg
|
|
for kwarg in kwargs.values():
|
|
out += kwarg
|
|
return out
|
|
|
|
pos0 = torch.randn(4)
|
|
pos1 = torch.randn(4)
|
|
unnamed_pos = torch.randn(4)
|
|
kw0 = torch.randn(4)
|
|
args = (pos0, pos1, unnamed_pos)
|
|
kwargs = {"kw0": kw0, "kw2": torch.randn(4), "unnamed_kw": torch.randn(4)}
|
|
expected_argument_names = [
|
|
"pos0",
|
|
"pos1",
|
|
"args_0", # 3rd unnamed positional argument
|
|
] + list(kwargs.keys())
|
|
m = MyModule()
|
|
|
|
self._test_export_preserving_original_signature(
|
|
m.method_to_test, expected_argument_names, *args, **kwargs
|
|
)
|
|
|
|
def test_export_with_functools_wrapped_fn(self):
|
|
def test_decorator(func):
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
return func(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
@test_decorator
|
|
def _fn(pos0, pos1=1.0, *args, kw0, kw1=2.0, **kwargs):
|
|
out = pos0
|
|
out += pos1
|
|
out += kw0
|
|
out += kw1
|
|
for arg in args:
|
|
out += arg
|
|
for kwarg in kwargs.values():
|
|
out += kwarg
|
|
return out
|
|
|
|
def wrapped_fn(*args, **kwargs):
|
|
return _fn(*args, **kwargs)
|
|
|
|
pos0 = torch.randn(4)
|
|
kw0 = torch.randn(4)
|
|
args = (pos0, torch.randn(4), torch.randn(4))
|
|
kwargs = {"kw0": kw0, "kw2": torch.randn(4)}
|
|
expected_argument_names = [f"args_{i}" for i in range(len(args))] + list(
|
|
kwargs.keys()
|
|
)
|
|
|
|
self._test_export_preserving_original_signature(
|
|
wrapped_fn, expected_argument_names, *args, **kwargs
|
|
)
|
|
|
|
def _test_export_preserving_original_signature(
|
|
self, fn, expected_argument_names: Sequence[str], *args, **kwargs
|
|
):
|
|
torch._dynamo.reset()
|
|
exported = torch._dynamo.export(
|
|
fn,
|
|
*args,
|
|
**kwargs,
|
|
aten_graph=False,
|
|
)
|
|
|
|
out_graph = exported[0]
|
|
dynamo_result = out_graph(*args, **kwargs)
|
|
real_result = fn(*args, **kwargs)
|
|
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))
|
|
|
|
# Check that the exported graph preserves same argument names.
|
|
self.assertEqual(
|
|
inspect.getfullargspec(out_graph.forward).args[1:], expected_argument_names
|
|
)
|
|
|
|
def test_dataclass_input_output(self):
|
|
from dataclasses import dataclass
|
|
|
|
@dataclass
|
|
class Tensors:
|
|
x: torch.Tensor
|
|
y: torch.Tensor
|
|
|
|
def f(t):
|
|
return t.x + t.y
|
|
|
|
with self.assertRaisesRegex(
|
|
UserError,
|
|
"It looks like one of the inputs with type .*Tensors.* "
|
|
"is not supported or pytree-flattenable",
|
|
):
|
|
torch._dynamo.export(f, aten_graph=False)(
|
|
Tensors(x=torch.randn(10), y=torch.randn(10))
|
|
)
|
|
|
|
def f(x, y):
|
|
return Tensors(x=x.sin(), y=y.cos())
|
|
|
|
with self.assertRaisesRegex(
|
|
UserError,
|
|
"It looks like one of the outputs with type .*Tensors.* "
|
|
"is not supported or pytree-flattenable",
|
|
):
|
|
torch._dynamo.export(f, aten_graph=False)(torch.randn(10), torch.randn(10))
|
|
|
|
def test_empty(self):
|
|
def f(x):
|
|
return x
|
|
|
|
exported = torch._dynamo.export(f)(torch.randn(3, 3))
|
|
out_graph = exported[0]
|
|
inp = torch.randn(3, 3)
|
|
self.assertTrue(torch._dynamo.utils.same(inp, out_graph(inp)))
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.a = torch.ones(3, 3)
|
|
|
|
def forward(self):
|
|
return self.a
|
|
|
|
exported = torch._dynamo.export(M())()
|
|
out_graph = exported[0]
|
|
self.assertTrue(torch._dynamo.utils.same(torch.ones(3, 3), out_graph()))
|
|
|
|
def test_export_meta(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.p = torch.nn.Parameter(torch.ones(2, 3))
|
|
|
|
def forward(self, x):
|
|
return self.p + x
|
|
|
|
with torch.device("meta"):
|
|
m = MyModule()
|
|
|
|
inp = torch.ones(2, 3, device="meta")
|
|
exported = torch._dynamo.export(m)(inp)
|
|
out_graph = exported[0]
|
|
dynamo_result = out_graph(inp)
|
|
self.assertEqual(dynamo_result, m(inp))
|
|
|
|
def test_constraint_violation_error_messages(self):
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x):
|
|
if x.shape[0] == x.shape[1] * 2:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
foo = Foo()
|
|
|
|
t = torch.zeros([8, 4])
|
|
dim0 = torch.export.Dim("dim0", min=3, max=10)
|
|
dim1 = torch.export.Dim("dim1")
|
|
dynamic_shapes = {"x": (dim0, dim1)}
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError,
|
|
"Constraints violated .*!(.*\n)*.*"
|
|
"by dim0 = 2\\*dim1(.*\n)*.*"
|
|
"Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
|
|
):
|
|
torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
|
|
|
class Bar(torch.nn.Module):
|
|
def forward(self, x):
|
|
if x.shape[0] == 5:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
bar = Bar()
|
|
|
|
t = torch.zeros([5])
|
|
dim0 = torch.export.Dim("dim0", min=3, max=8)
|
|
dynamic_shapes = {"x": (dim0,)}
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError,
|
|
"Not all values.*valid.*inferred to be a constant",
|
|
):
|
|
torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
|
|
|
class Qux(torch.nn.Module):
|
|
def forward(self, x):
|
|
if x.shape[0] > 5 and x.shape[0] < 10:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
qux = Qux()
|
|
|
|
t = torch.zeros([7])
|
|
dim0 = torch.export.Dim("dim0", min=3, max=8)
|
|
dynamic_shapes = {"x": (dim0,)}
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError,
|
|
"Not all values.*satisfy the generated guard",
|
|
):
|
|
torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
|
|
|
def test_untracked_inputs_in_constraints(self):
|
|
from copy import copy
|
|
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return y + 1
|
|
|
|
foo = Foo()
|
|
|
|
x = torch.randn(2)
|
|
y = torch.randn(5, 4)
|
|
|
|
dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y")
|
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
|
|
|
example_inputs = (copy(x), y)
|
|
ep = torch.export.export(
|
|
foo, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
|
|
)
|
|
ep.module()(torch.randn(3), y) # no specialization error
|
|
|
|
def test_export_raise_guard_full_constraint(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
if x.shape[0] == 3:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
torch._dynamo.export(my_dyn_fn)(y)
|
|
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.export(
|
|
my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
|
|
)(y)
|
|
|
|
def test_export_module_specify_constraints_signature(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, x):
|
|
if x.shape[0] == 3:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
mod = Mod()
|
|
torch._dynamo.export(mod)(y)
|
|
|
|
with self.assertRaisesRegex(ConstraintViolationError, "dimx = 3"):
|
|
torch._dynamo.export(mod, dynamic_shapes=({0: torch.export.Dim("dimx")},))(
|
|
y
|
|
)
|
|
|
|
def test_export_raise_guard_partial_constraint(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
if x.shape[0] > 3:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
torch._dynamo.export(my_dyn_fn)(y)
|
|
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.export(
|
|
my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
|
|
)(y)
|
|
|
|
def test_export_raise_on_relationship(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(a, b, c):
|
|
if a.shape[0] == b.shape[1] == c.shape[2]:
|
|
return a.sin()
|
|
|
|
return a.cos()
|
|
|
|
torch._dynamo.export(my_dyn_fn)(y, y, y)
|
|
dim = torch.export.Dim("dim")
|
|
dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
|
|
dynamic_shapes = ({0: dim}, {1: dim}, {2: dim})
|
|
torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
|
|
|
|
def test_export_no_raise(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(a, b, c):
|
|
if a.shape[1] == 3:
|
|
return a.cos()
|
|
return a * b * c
|
|
|
|
torch._dynamo.export(my_dyn_fn)(y, y, y)
|
|
dim = torch.export.Dim("dim")
|
|
dynamic_shapes = ({0: dim}, {0: dim}, {0: dim})
|
|
torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(y, y, y)
|
|
|
|
def test_export_multi_dynamic_dim_unsafe_relationship(self):
|
|
x = torch.randn([3, 3, 3])
|
|
y = torch.randn([2, 2, 2])
|
|
z = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(a, b, c):
|
|
if a.shape[0] == c.shape[0]:
|
|
return a.cos()
|
|
return a * c, b
|
|
|
|
torch._dynamo.export(my_dyn_fn)(x, y, z)
|
|
dimx, dimy, dimz = torch.export.dims("dimx", "dimy", "dimz")
|
|
dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
|
|
dimz = dimx
|
|
dynamic_shapes = ({0: dimx}, {0: dimy}, {0: dimz})
|
|
torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
|
|
|
|
def test_remove_redundant_dynamic_dim_in_error_message(self):
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
if x.shape[0] == y["k"].shape[0]:
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
foo = Foo()
|
|
|
|
a = torch.randn(3)
|
|
b = torch.randn(3)
|
|
dim0_a, dim0_b = torch.export.dims("dim0_a", "dim0_b")
|
|
with self.assertRaisesRegex(torch._dynamo.exc.UserError, "dim0_b = dim0_a"):
|
|
torch.export.export(
|
|
foo,
|
|
(a, {"k": b}),
|
|
dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
|
|
strict=True,
|
|
)
|
|
|
|
def test_enforce_equalities(self):
|
|
class Bar(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return torch.matmul(x, y)
|
|
|
|
bar = Bar()
|
|
|
|
batch, size = torch.export.dims("batch", "size")
|
|
dynamic_shapes = {"x": (batch, size, size), "y": (batch, size, size)}
|
|
|
|
x = torch.randn(10, 3, 3)
|
|
y = torch.randn(10, 3, 4)
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError,
|
|
".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
|
|
):
|
|
torch.export.export(bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True)
|
|
y = torch.randn(10, 3, 3)
|
|
ebar = torch.export.export(
|
|
bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
|
|
)
|
|
self.assertEqual(
|
|
[
|
|
str(node.meta["val"].shape)
|
|
for node in ebar.graph_module.graph.nodes
|
|
if node.op == "placeholder"
|
|
],
|
|
["torch.Size([s17, s27, s27])", "torch.Size([s17, s27, s27])"],
|
|
)
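        # Both placeholders end up with the same symbols (s17, s27) because the
        # dynamic_shapes spec above ties the batch and size dims of x and y together.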

    @torch._dynamo.config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_export_preserve_constraints_as_metadata_tensor(self):
        def f(x):
            b = x.nonzero()
            torch._check(b.shape[0] >= 2)
            torch._check(b.shape[0] <= 5)
            return b

        y = torch.tensor([8, 8, 6])
        torch._dynamo.export(
            f,
            aten_graph=True,
            tracing_mode="symbolic",
        )(y)

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        specialize_int=True,
        capture_scalar_outputs=True,
    )
    def test_exported_graph_serialization(self):
        def f(x, y):
            b = x.item()
            torch._check_is_size(b)
            return torch.empty((b, y.shape[0]))

        x = torch.tensor([3])
        y = torch.randn([8, 8, 6])
        example_inputs = [x, y]
        dynamic_shapes = (None, {0: torch.export.Dim("dimy", min=6, max=10)})
        gm, _ = torch._dynamo.export(
            f,
            dynamic_shapes=dynamic_shapes,
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs)

        # Ensure the exported graph module with metadata is serializable;
        # the metadata won't be saved in the serialized module.
        buffer = io.BytesIO()
        torch.save(gm, buffer)

    def test_export_dynamic_dim_not_1(self):
        x = torch.randn([1, 1, 1])

        def my_dyn_fn(a):
            if a.shape[0] != 1:
                return a.cos()
            return a * a

        torch._dynamo.export(my_dyn_fn)(x)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dimx")},)
            )(x)

    def test_symbool(self):
        def f(x):
            a = torch.scalar_tensor(x.shape[0] > 4)
            return x.sin().sum() + a.sum()

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
        self.assertEqual(gm(torch.ones(3, 4)), f(torch.ones(3, 4)))

    def test_export_multi_dynamic_dim_constraint(self):
        x = torch.randn([3, 3, 3])
        y = torch.randn([2, 2, 2])
        z = torch.randn([3, 3, 3])

        def my_dyn_fn(a, b, c):
            if a.shape[0] == c.shape[0]:
                return a.cos()
            return a * c, b

        torch._dynamo.export(my_dyn_fn)(x, y, z)
        dimx_0, dimx_1, dimx_2 = torch.export.dims("dimx_0", "dimx_1", "dimx_2")
        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, None)
        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)
        dynamic_shapes = ({0: dimx_0, 1: dimx_1, 2: dimx_2}, None, {0: dimx_0})
        torch._dynamo.export(my_dyn_fn, dynamic_shapes=dynamic_shapes)(x, y, z)

    def test_export_dynamic_dim_range_constraint(self):
        x = torch.ones(6, 4, 4)
        dynamic_shapes = ({0: torch.export.Dim("dimx", min=5, max=6)},)

        def foo(x):
            if x.shape[0] > 3:  # ok
                return x.sin()
            return x.cos()

        torch._dynamo.export(
            foo,
            dynamic_shapes=dynamic_shapes,
            aten_graph=True,
        )(x)

        def bar(x):
            if x.shape[0] > 5:  # error
                return x.sin()
            return x.cos()

        with self.assertRaises(ConstraintViolationError):
            torch._dynamo.export(
                bar,
                dynamic_shapes=dynamic_shapes,
                aten_graph=True,
            )(x)

    def test_trivial_constraint(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                # complex divisibility condition
                if (2 * x.shape[0] + 3) % (x.shape[0] - 3) == 0:
                    return x + 1
                else:
                    return x - 1

        foo = Foo()

        class Bar(torch.nn.Module):
            def forward(self, x):
                # trivially true
                if (2 * x.shape[0] + 2) % (x.shape[0] + 1) == 0:
                    return x + 1
                else:
                    return x - 1

        bar = Bar()

        class Qux(torch.nn.Module):
            def forward(self, x):
                # simple divisibility condition (not trivially true)
                if (3 * x.shape[0]) % 2 == 0:
                    return x + 1
                else:
                    return x - 1

        qux = Qux()

        x = torch.randn(12)
        dim0 = torch.export.Dim("dim0", max=100)
        dynamic_shapes = {"x": (dim0,)}
        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            r"Constraints violated \(dim0\)",
        ):
            torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes, strict=True)

        torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes, strict=True)

        with self.assertRaisesRegex(
            torch._dynamo.exc.UserError,
            r"Constraints violated \(dim0\)",
        ):
            torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes, strict=True)

    def test_list_contains(self):
        def func(x):
            assert x.size(-1) in [4, 5, 6], "bad"
            return x + x

        inps = (torch.randn(1, 5),)
        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_list_not_contains(self):
        def func(x):
            assert x.size(0) not in [4, 5, 6], "bad1"
            assert "monkey" not in ["cow", "pig"], "bad2"
            return x + x

        inps = (torch.randn(1, 5),)
        opt_func = torch.compile(func, backend="eager", fullgraph=True, dynamic=True)
        real_result = opt_func(*inps)

        torch._dynamo.reset()

        exported = torch._dynamo.export(func, aten_graph=True)(*inps)
        out_graph = exported[0]

        dynamo_result = out_graph(*inps)

        self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))

    def test_export_identity(self):
        inp = torch.tensor([0.1, 0.1])

        def func(x):
            return x

        torch._dynamo.reset()
        exported, _ = torch._dynamo.export(func)(inp)
        dynamo_result = exported(inp)
        self.assertTrue(torch._dynamo.utils.same(inp, dynamo_result))

    def test_export_specialized_int(self):
        class Foo(torch.nn.Module):
            def __init__(
                self,
                input_dim,
            ):
                super().__init__()
                self.torch_module = torch.nn.LayerNorm(
                    input_dim, eps=1e-5, elementwise_affine=True
                )
                self.int_val = 100

            def forward(self, input):
                return input.cos() * self.int_val * self.torch_module.eps

        mod = Foo(128)
        inp = torch.randn(3, 128)

        # In export, ints & floats used in forward should always be specialized.
        gm, _ = torch._dynamo.export(mod, aten_graph=True)(inp)
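        # Only the tensor input should appear as a graph placeholder below;
        # self.int_val and the LayerNorm eps get baked in as constants.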
        count = 0
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                count += 1
        self.assertEqual(count, 1)

    def test_export_with_nonzero_static(self):
        class BasicModule(torch.nn.Module):
            def __init__(self, static_size):
                super().__init__()
                self.static_size = static_size

            def forward(self, x):
                return torch.nonzero_static(x, size=self.static_size)

        input_tensors = torch.tensor([6, 8]), torch.zeros(2, 3)
        static_sizes = 3, 4
        for input_tensor, static_size in zip(input_tensors, static_sizes):
            m = BasicModule(static_size)
            gm, _ = torch._dynamo.export(m, aten_graph=True)(input_tensor)
            res = gm(input_tensor)
            self.assertEqual(res.size(0), static_size)
            self.assertTrue(
                torch._dynamo.utils.same(
                    res, torch.nonzero_static(input_tensor, size=static_size)
                )
            )

    def test_export_pass_arg_by_name(self):
        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, x):
                return self.my_lin(x)

        mod, input_tensor = BasicModule(), torch.randn(2, 3)
        gm, _ = torch._dynamo.export(mod, aten_graph=True)(input_tensor)
        ref = mod(x=input_tensor)
        res = gm(x=input_tensor)
        self.assertTrue(torch._dynamo.utils.same(ref, res))

    def test_export_pass_arg_by_name_star_args(self):
        class BasicModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_lin = torch.nn.Linear(3, 4, bias=True)

            def forward(self, *args):
                return self.my_lin(args[0]) * self.my_lin(args[1])

        mod, input_tensor, input_tensor2 = (
            BasicModule(),
            torch.randn(2, 3),
            torch.randn(2, 3),
        )
        gm, _ = torch._dynamo.export(mod, aten_graph=True)(input_tensor, input_tensor2)
        ref = mod(input_tensor, input_tensor2)
        res = gm(input_tensor, input_tensor2)
        self.assertTrue(torch._dynamo.utils.same(ref, res))

    def test_export_dynamic_dim_cleanup(self):
        y = torch.randn([3, 3, 3])

        def my_dyn_fn(x):
            return x.cos()

        torch._dynamo.export(my_dyn_fn, dynamic_shapes=({0: torch.export.Dim("dim")},))(
            y
        )

    @config.patch(capture_dynamic_output_shape_ops=True)
    def test_export_dynamic_control_flow_error(self):
        def f(x):
            if x.nonzero() > 3:
                return x.cos()
            return x.sin()

        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported,
            "Data-dependent branching",
        ):
            torch._dynamo.export(f, aten_graph=True)(torch.randn(5, 6))

    @config.patch(assume_static_by_default=False)
    def test_export_persist_assert(self):
        def f(x):
            assert x[0].sum() > 4, "Shape must be more than 4"
            return x.cos() + x.sin()

        gm, _ = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
            torch.ones(5, 4, 6)
        )

        def has_aten_op(gm, op):
            for node in gm.graph.nodes:
                if node.target == op:
                    return True
            return False

        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
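        # The assert should also survive dead code elimination below:
        # _assert_async is effectful, so it must not be treated as dead
        # even though its output is unused.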

        gm.graph.eliminate_dead_code()
        gm.recompile()
        self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))

        with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
            gm(torch.zeros(3, 4, 5))

    @common_utils.parametrize(
        "type_fn",
        [
            common_utils.subtest(type, name="builtin"),
            common_utils.subtest(lambda obj: obj.__class__, name="attr"),
        ],
    )
    def test_access_class_method_from_user_class(self, type_fn):
        class A:
            @classmethod
            def func(cls):
                return torch.Tensor([4, 5])

        def f(x):
            a = A()
            return x.sum() + type_fn(a).func().sum()

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))

    def test_not_functionalize(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.ones(6, 2))

            def forward(self, x):
                x.add_(2)
                return x.sum() + self.buffer1.sum()

        example_inputs = (torch.ones(1, 2, 3),)
        gm, _ = torch._dynamo.export(
            Foo(),
            aten_graph=True,
            tracing_mode="symbolic",
        )(*example_inputs)
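        # Without functionalization, the in-place aten.add_ should survive
        # in the exported graph exactly as written.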
        count = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten.add_.Tensor:
                count += 1
        self.assertEqual(count, 1)
        test_inp = (torch.ones(1, 2, 3),)
        test_inp_v2 = (torch.ones(1, 2, 3),)
        self.assertEqual(gm(*test_inp), Foo()(*test_inp_v2))

    def test_round_dynamic_shapes(self):
        def f(x):
            return x[: round(x.shape[0] / 2)]

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))

        self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))

    def test_cond_supported_pred_types(self):
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f_pred_traced_as_symnode_var(x):
            return cond(x.shape[0] > 2, true_fn, false_fn, [x])

        def f_pred_traced_as_tensor_var(x):
            return cond(x.all(), true_fn, false_fn, [x])

        def f_pred_complex_expression_traced_as_symnode_var(x):
            return cond(
                x.dim() > 1 and x.shape[1] > 5 and x.shape[1] <= 10,
                true_fn,
                false_fn,
                [x],
            )

        example_inputs = (torch.rand(5, 8),)
        for f in [
            f_pred_traced_as_symnode_var,
            f_pred_traced_as_tensor_var,
            f_pred_complex_expression_traced_as_symnode_var,
        ]:
            gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs)
            self.assertEqual(gm(*example_inputs), f(*example_inputs))

    @unittest.expectedFailure  # TODO: Not sure why dynamo creates new inputs for self.a
    def test_sum_param(self):
        # Setting a new attribute inside forward()
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.randn(3, 2)

            def forward(self, x):
                self.b = 2
                return x.sum() + self.a.sum() + self.b

        torch._dynamo.export(Foo())(torch.randn(3, 2))

    def test_mixed_real_and_fake_inputs(self):
        class _TestPattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.bn = torch.nn.BatchNorm2d(1)

            def forward(self, input):
                running_std = torch.sqrt(self.bn.running_var + self.bn.eps)
                scale_factor = self.bn.weight / running_std
                weight_shape = [1] * len(self.conv.weight.shape)
                weight_shape[0] = -1
                bias_shape = [1] * len(self.conv.weight.shape)
                bias_shape[1] = -1
                scaled_weight = self.conv.weight * scale_factor.reshape(weight_shape)
                zero_bias = torch.zeros_like(self.conv.bias, dtype=input.dtype)
                conv = self.conv._conv_forward(input, scaled_weight, zero_bias)
                conv_orig = conv / scale_factor.reshape(bias_shape)
                conv_orig = conv_orig + self.conv.bias.reshape(bias_shape)
                conv = self.bn(conv_orig)
                return conv

        example_inputs = (torch.randn(1, 1, 3, 3),)
        torch._dynamo.export(
            _TestPattern(),
            aten_graph=True,
        )(*example_inputs)

    @config.patch(
        capture_dynamic_output_shape_ops=True,
        capture_scalar_outputs=True,
        assume_static_by_default=False,
    )
    def test_sym_contains(self):
        def f(x, y):
            return x.size(0) in y

        gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(2), torch.ones(3))

        true_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(5))
        false_inp = (torch.Tensor([6, 4, 5]), torch.ones(6, 4).add_(2))
        self.assertEqual(gm(*true_inp), f(*true_inp))
        self.assertEqual(gm(*false_inp), f(*false_inp))

    def test_cond_raise_user_error_on_missing_args(self):
        def true_fn(x):
            return x.cos()

        def false_fn(x):
            return x.sin()

        def f(x):
            return cond(x.shape[0] > 10, true_fn, false_fn)

        # Now we allow torch.cond to handle empty args
        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            TypeError,
            r"false_fn\(\) missing 1 required positional argument: 'x'",
        ):
            f(*example_inputs)

    def test_cond_raise_user_error_on_unsupported_pred(self):
        def f_unsupported_pred(x):
            pred = torch.nn.Module()
            return cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            RuntimeError,
            "Expected pred to be bool or tensor, but got Module()",
        ):
            f_unsupported_pred(*example_inputs)

    def test_cond_raise_user_error_on_non_list_operands(self):
        def f_non_list_operands(x):
            return cond(torch.tensor(True), lambda x: x.sin(), lambda x: x.cos(), x)

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
        ):
            f_non_list_operands(*example_inputs)

    def test_cond_raise_user_error_on_non_tensor_operands(self):
        def f_non_tensor_operands(x):
            a: float = 3.14
            return cond(
                torch.tensor(1234), lambda x, a: x.sin(), lambda x, a: x.cos(), [x, a]
            )

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            RuntimeError,
            r"Expect operands to be a tuple of possibly nested dict/list/tuple",
        ):
            f_non_tensor_operands(*example_inputs)

    def test_cond_raise_user_error_on_branch_args_mismatch(self):
        def true_fn(x, y):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f_branch_args_mismatch(x, y):
            return cond(torch.tensor([[[[True]]]]), true_fn, false_fn, [x, y])

        example_inputs = (torch.rand(5), torch.rand(2))
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compil",
        ):
            torch._dynamo.export(
                f_branch_args_mismatch,
                aten_graph=True,
            )(
                *example_inputs,
            )

    @config.patch(suppress_errors=True)
    def test_uncaptured_higher_order_op_error_not_suppresed(self):
        def true_fn(x, y):
            return x.sin()

        def false_fn(x):
            return x.cos()

        def f_branch_args_mismatch(x, y):
            return cond(torch.tensor([[[[100]]]]), true_fn, false_fn, [x, y])

        example_inputs = (torch.rand(5), torch.rand(2))
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            torch._dynamo.export(
                f_branch_args_mismatch,
                aten_graph=True,
            )(
                *example_inputs,
            )

    def test_cond_raise_user_error_on_branch_return_non_tensor(self):
        def f_branch_return_non_tensor(x):
            return cond(x.shape[0] <= 5, lambda x: 3.14, lambda x: 3.14, [x])

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            torch._dynamo.exc.UncapturedHigherOrderOpError,
            "Cond doesn't work unless it is captured completely with torch.compile",
        ):
            torch._dynamo.export(
                f_branch_return_non_tensor,
                aten_graph=True,
            )(*example_inputs)

    def test_cond_raise_user_error_on_branch_return_multiple_tensors(self):
        def f_branch_return_multiple_tensors(pred, x, y):
            return cond(pred, lambda x: (x, x), lambda x: (x, x), [y])

        example_inputs = (torch.tensor(True), torch.randn(4), torch.randn(2))
        gm, _ = torch._dynamo.export(
            f_branch_return_multiple_tensors,
            aten_graph=True,
        )(*example_inputs)
        self.assertEqual(
            gm(*example_inputs), f_branch_return_multiple_tensors(*example_inputs)
        )

    def test_multiple_outputs_op_with_evaluator(self):
        class TopKModel(torch.nn.Module):
            def forward(self, x):
                values, _ = torch.topk(x, 3)
                return torch.sum(values)

        x = torch.arange(1.0, 6.0, requires_grad=True)
        torch._dynamo.export(TopKModel())(x)

    def test_cond_raise_user_error_on_mismatch_return_length(self):
        def true_fn(x):
            return x

        def false_fn(x):
            return (x, x)

        def f_mismatch_return_length(x):
            return cond(torch.tensor(100), true_fn, false_fn, [x])

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            "Unmatched output spec from torch.cond branches",
        ):
            torch._dynamo.export(
                f_mismatch_return_length,
                aten_graph=True,
            )(*example_inputs)

    def test_cond_raise_user_error_on_mismatch_return_tensor_meta(self):
        def true_fn(x):
            return torch.tensor([[3], [2]])

        def false_fn(x):
            return torch.tensor([3.14])

        def f_return_tensor_mismatch(x):
            return cond(x.shape[0] < 3, true_fn, false_fn, [x])

        example_inputs = (torch.rand(5),)
        with self.assertRaisesRegex(
            torch._dynamo.exc.TorchRuntimeError,
            "When merging two branches' output in torch.cond",
        ):
            torch._dynamo.export(f_return_tensor_mismatch, aten_graph=True)(
                *example_inputs,
            )

    def test_byte_tensor_does_not_crash(self):
        # See https://github.com/pytorch/pytorch/issues/100455
        def func(text):
            tensor = torch.ByteTensor(list(bytes(text, "utf8")))
            return tensor + tensor

        text = "".join(chr(a % 90 + 40) for a in range(111))
        opt_func = torch.compile(func, backend="eager", dynamic=True)
        for i in [99, 100]:
            input = text[:i]
            opt_func(input)

    def test_export_defaults_ok(self):
        class DynamicSliceExportMod(torch.nn.Module):
            def forward(self, x):
                results = []
                for i in range(4):
                    results.append(x[: x.size(0) - i, i : x.size(2), i:3])
                return tuple(results)

        gm, _ = torch._dynamo.export(DynamicSliceExportMod(), aten_graph=True)(
            torch.randn(5, 5, 5),
        )

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    arg0_1 = arg0
    sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
    slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
    sub = sym_size_int - 1
    slice_2 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub); sub = None
    slice_3 = torch.ops.aten.slice.Tensor(slice_2, 1, 1, sym_size_int); slice_2 = None
    slice_4 = torch.ops.aten.slice.Tensor(slice_3, 2, 1, 3); slice_3 = None
    sub_1 = sym_size_int - 2
    slice_5 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_1); sub_1 = None
    slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 2, sym_size_int); slice_5 = None
    slice_7 = torch.ops.aten.slice.Tensor(slice_6, 2, 2, 3); slice_6 = None
    sub_2 = sym_size_int - 3
    slice_8 = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, sub_2); arg0_1 = sub_2 = None
    slice_9 = torch.ops.aten.slice.Tensor(slice_8, 1, 3, sym_size_int); slice_8 = sym_size_int = None
    slice_10 = torch.ops.aten.slice.Tensor(slice_9, 2, 3, 3); slice_9 = None
    return pytree.tree_unflatten([slice_1, slice_4, slice_7, slice_10], self._out_spec)""",
        )

    def test_capture_symbolic_tracing_simple_within_fake_mode(self):
        from torch._dynamo.output_graph import config

        def f(x):
            y = torch.randn(3)
            return x + x * y

        with fake_tensor.FakeTensorMode(
            shape_env=ShapeEnv(
                allow_scalar_outputs=config.capture_scalar_outputs,
                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
            ),
        ):
            x = torch.randn(3)

            for aten_graph in [True, False]:
                gm, _ = torch._dynamo.export(f, aten_graph=aten_graph)(x)
                self.assertTrue(
                    isinstance(gm, torch.fx.GraphModule),
                    msg="test_capture_symbolic_tracing_simple_within_fake_mode_aten_graph_"
                    + str(aten_graph),
                )

    def test_export_with_symbool_inputs(self):
        def f(pred: bool, x: torch.Tensor):
            if pred:
                return x.sin()
            else:
                return x.cos()

        x = torch.randn([3, 4])

        def test_symbool_guards(
            f, size_tests, exp_graph, exp_guard_code, exp_shape_env_guards
        ):
            shape_env = ShapeEnv()
            with fake_tensor.FakeTensorMode(
                shape_env=shape_env,
            ) as fake_mode:
                fake_x = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[DimDynamic.DYNAMIC for _ in range(x.dim())],
                    ),
                )
                for i, size in enumerate(size_tests):
                    pred = fake_x.size(0) == size
                    gm, guards = torch._dynamo.export(f)(pred, x)
                    actual = normalize_gm(gm.print_readable(print_output=False))
                    # TODO: This is naughty, EXPECTTEST_ACCEPT=1 doesn't work
                    self.assertExpectedInline(actual, exp_graph[i])
                    dynamo_shape_env_guards = [
                        guard
                        for guard in guards
                        if guard.guard_types is not None
                        and "SHAPE_ENV" in guard.guard_types
                    ]
                    self.assertEqual(len(dynamo_shape_env_guards), 1)
                    guard_code_on_predicate = [
                        code
                        for code in dynamo_shape_env_guards[0].code_list
                        if "L['pred']" in code
                    ]
                    self.assertEqual(guard_code_on_predicate, exp_guard_code[i])
                    outter_shape_env_guards = [
                        str(guard.expr) for guard in shape_env.guards
                    ]
                    self.assertEqual(outter_shape_env_guards, exp_shape_env_guards[i])

        true_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, pred, x):
        arg1: "f32[s77, s27]";

        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
        l_x_ = arg1

        sin: "f32[s77, s27]" = l_x_.sin(); l_x_ = None
        return pytree.tree_unflatten([sin], self._out_spec)
"""
        false_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, pred, x):
        arg1: "f32[s77, s27]";

        arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
        l_x_ = arg1

        cos: "f32[s77, s27]" = l_x_.cos(); l_x_ = None
        return pytree.tree_unflatten([cos], self._out_spec)
"""
        true_guard_code = [
            "cast_symbool_to_symint_guardless(L['pred']) == 1",
        ]
        false_guard_code = [
            "cast_symbool_to_symint_guardless(L['pred']) != 1",
        ]
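        # Dynamo guards on the SymBool input by casting it to a symint and
        # comparing against 1, which is what the guard strings above encode.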
        test_symbool_guards(
            f,
            [3, 3, 4, 5],
            [true_graph, true_graph, false_graph, false_graph],
            [true_guard_code, true_guard_code, false_guard_code, false_guard_code],
            # The outer shape env should have no guards in it because we never
            # specialize on the outer symbool.
            [[], [], [], []],
        )

    def test_invalid_input_global(self) -> None:
        global bulbous_bouffant
        bulbous_bouffant = torch.randn(3)

        def f(y):
            return bulbous_bouffant + y

        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """\
G['bulbous_bouffant'], accessed at:
  File "test_export.py", line N, in f
    return bulbous_bouffant + y
""",
        )

    def test_invalid_input_global_multiple_access(self) -> None:
        global macademia
        macademia = torch.randn(3)

        def g(y):
            global macademia
            y = macademia + y
            return y

        def f(y):
            global macademia
            y = g(y)
            return macademia + y

        # NB: This doesn't actually work (it only reports the first usage),
        # but I'm leaving the test here in case we fix it later
        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """\
G['macademia'], accessed at:
  File "test_export.py", line N, in f
    y = g(y)
  File "test_export.py", line N, in g
    y = macademia + y
""",
        )

    def test_invalid_input_nonlocal(self) -> None:
        arglebargle = torch.randn(3)

        def f(y):
            return arglebargle + y

        self.assertExpectedInlineMunged(
            UserError,
            lambda: torch._dynamo.export(f)(torch.randn(3)),
            """L['arglebargle'], a closed over free variable""",
        )

    def test_invalid_input_unused_nonlocal_ok(self) -> None:
        arglebargle = torch.randn(3)

        def f(y):
            x = arglebargle  # noqa: F841
            return y

        torch._dynamo.export(f)(torch.randn(3))

    def test_symbolic_tracing_within_fake_mode_with_constraints(self):
        from torch._subclasses import fake_tensor

        fake_mode = fake_tensor.FakeTensorMode()

        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        with fake_mode:
            model = DynamicShapeSimpleModel()
            inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
            dim = torch.export.Dim("dim")
            dynamic_shapes = ({0: dim}, None, {0: dim})
            for aten_graph in [True, False]:
                gm = torch._dynamo.export(
                    model,
                    dynamic_shapes=dynamic_shapes,
                    aten_graph=aten_graph,
                )(*inputs).graph_module

        # Since there are no parameters, we can run the exported graph
        # directly on real inputs.
        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        self.assertEqual(model(*inputs), gm(*inputs))

    def test_symbolic_tracing_within_fake_mode_with_constraints_with_parameters(self):
        from torch._subclasses import fake_tensor

        fake_mode = fake_tensor.FakeTensorMode()

        # TODO: Seems to choke if you don't make a fresh model and
        # just try to export Linear directly...
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                return out

        with fake_mode:
            model = Model()
            inputs = (torch.randn(10, 2, 2),)
            dynamic_shapes = ({0: torch.export.Dim("dim")},)
            for aten_graph in [True, False]:
                torch._dynamo.export(
                    model,
                    dynamic_shapes=dynamic_shapes,
                    aten_graph=aten_graph,
                )(*inputs).graph_module

    def test_capture_symbolic_tracing_within_fake_mode(self):
        from torch._dynamo.output_graph import config
        from torch._subclasses import fake_tensor
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)
                self.linear2 = torch.nn.Linear(2, 2)

            def forward(self, x):
                out = self.linear(x)
                out = self.linear2(out)
                return out

        # User-instantiated FakeTensorMode
        fake_mode = fake_tensor.FakeTensorMode(
            allow_non_fake_inputs=False,
            allow_fallback_kernels=True,
            shape_env=ShapeEnv(
                allow_scalar_outputs=config.capture_scalar_outputs,
                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
            ),
        )
        # Fakeify the input and model before exporting them
        with fake_mode:
            x = torch.rand(5, 2, 2)
            model = Model()

        # Export the model with fake inputs and parameters
        for aten_graph in [True, False]:
            graph_module, _ = torch._dynamo.export(model, aten_graph=aten_graph)(x)
            self.assertTrue(
                isinstance(graph_module, torch.fx.GraphModule),
                msg="test_capture_symbolic_tracing_within_fake_mode_aten_graph_"
                + str(aten_graph),
            )

    def test_cond_op_param_buffer_lifted(self):
        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def forward(self, x):
                def true_fn(x):
                    return x.cos() + self.a()

                def false_fn(x):
                    return x.sin() + self.b()

                return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

        gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
        self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
        self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))

    def test_nested_cond_op_param_buffer_lifted(self):
        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def forward(self, x):
                def true_true_fn(x):
                    return x.cos() + self.a()

                def true_false_fn(x):
                    return x.cos() + self.a() + 1

                def true_fn(x):
                    return cond(x.shape[0] > 5, true_true_fn, true_false_fn, [x])

                def false_fn(x):
                    return x.sin() + self.b()

                return (cond(x.shape[0] > 4, true_fn, false_fn, [x]),)

        gm, _ = torch._dynamo.export(M(), aten_graph=False)(torch.ones(6, 4))
        self.assertEqual(gm(torch.ones(6, 4)), M()(torch.ones(6, 4)))
        self.assertEqual(gm(torch.ones(5, 4)), M()(torch.ones(5, 4)))
        self.assertEqual(gm(torch.ones(3, 4)), M()(torch.ones(3, 4)))

    def test_map_cond_param_buffer_lifted(self):
        from functorch.experimental.control_flow import cond, map

        class A(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(6, 4))

            def forward(self):
                return self.buffer1.sum()

        class B(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.ones(6, 4))

            def forward(self):
                return self.buffer2.sum()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = A()
                self.b = B()

            def inner(self, x, pred):
                def true_fn(x):
                    return x + x + self.a()

                def false_fn(x):
                    return x * x + self.b()

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred) + self.b()

                return map(body, xs, pred)

        mod = Module()
        x = torch.randn(3, 2, 1)
        pred_x = torch.tensor(True)

        y = torch.randn(4, 3, 2)
        pred_y = torch.tensor(False)
        real_result = mod(pred_y, y)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        self.assertEqual(real_result, out_graph(pred_y, y))

    def test_cond_free_variables_overlapping(self):
        from functorch.experimental.control_flow import cond

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, pred, x):
                a = torch.ones(6, 4)
                b = torch.ones(6, 4)
                c = torch.ones(6, 4)
                d = torch.ones(6, 4)

                def true_fn(x):
                    return x + x + a.cos() + b.cos() + d.cos()

                def false_fn(x):
                    return x * x + a.sin() + b.sin() + c.sin()

                return cond(pred, true_fn, false_fn, [x])

        mod = Module()
        x = torch.ones(6, 4)
        pred_x = torch.tensor(True)

        out_graph, _ = torch._dynamo.export(mod)(pred_x, x)
        self.assertExpectedInline(
            out_graph.code.strip(),
            """\
def forward(self, pred, x):
    arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
    l_pred_ = arg0
    l_x_ = arg1
    a = torch.ones(6, 4)
    b = torch.ones(6, 4)
    c = torch.ones(6, 4)
    d = torch.ones(6, 4)
    cond_true_0 = self.cond_true_0
    cond_false_0 = self.cond_false_0
    cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (a, b, l_x_, d, c)); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
    getitem = cond[0]; cond = None
    return pytree.tree_unflatten([getitem], self._out_spec)""",  # noqa: B950,E122
        )

        self.assertExpectedInline(
            out_graph.cond_true_0.code.strip(),
            """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
    a_1 = a
    b_1 = b
    l_x__1 = l_x_
    add = l_x__1 + l_x__1; l_x__1 = None
    cos = a_1.cos(); a_1 = None
    add_1 = add + cos; add = cos = None
    cos_1 = b_1.cos(); b_1 = None
    add_2 = add_1 + cos_1; add_1 = cos_1 = None
    cos_2 = d_true_branch.cos(); d_true_branch = None
    add_3 = add_2 + cos_2; add_2 = cos_2 = None
    return (add_3,)""",
        )

        self.assertExpectedInline(
            out_graph.cond_false_0.code.strip(),
            """\
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
    a_1 = a
    b_1 = b
    l_x__1 = l_x_
    mul = l_x__1 * l_x__1; l_x__1 = None
    sin = a_1.sin(); a_1 = None
    add = mul + sin; mul = sin = None
    sin_1 = b_1.sin(); b_1 = None
    add_1 = add + sin_1; add = sin_1 = None
    sin_2 = c_false_branch.sin(); c_false_branch = None
    add_2 = add_1 + sin_2; add_1 = sin_2 = None
    return (add_2,)""",
        )
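        # Free variables shared by both branches (a, b) are deduplicated and
        # lifted once, while d and c are each used by only one branch, hence
        # the d_true_branch / c_false_branch argument names above.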
|
|
|
|
    @unittest.skipIf(
        common_utils.TEST_WITH_ASAN,
        "Times out with ASAN, see https://github.com/pytorch/pytorch/issues/110416",
    )
    def test_retracibility(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                a, b = x
                a_conv = self.conv(a)
                a_linear = self.linear(a_conv)
                b_conv = self.conv(b)
                b_linear = self.linear(b_conv)
                return (
                    a_linear.cos() + b_linear.sin(),
                    a_linear.sin() + b_linear.cos(),
                )

        inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))

        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))

        self.assertTrue(torch.allclose(gm(inp_test)[0], gm2(inp_test)[0]))
        self.assertTrue(torch.allclose(gm(inp_test)[1], gm2(inp_test)[1]))

    def test_retracibility_dict_container_inp_out(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                a1, a2 = x["a"]
                b = x["b"]
                a1_conv = self.conv(a1)
                a1_linear = self.linear(a1_conv)
                a2_conv = self.conv(a2)
                a2_linear = self.linear(a2_conv)
                b_conv = self.conv(b)
                b_linear = self.linear(b_conv)
                return {
                    "a": [
                        a1_linear.cos() + b_linear.sin(),
                        a1_linear.cos() + b_linear.sin(),
                    ],
                    "b": a2_linear.sin() + b_linear.cos(),
                }

        inp_container = {
            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
            "b": torch.randn(20, 16, 50, 100),
        }

        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = {
            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
            "b": torch.randn(20, 16, 50, 100),
        }

        self.assertTrue(torch.allclose(gm(inp_test)["a"][0], gm2(inp_test)["a"][0]))
        self.assertTrue(torch.allclose(gm(inp_test)["a"][1], gm2(inp_test)["a"][1]))
        self.assertTrue(torch.allclose(gm(inp_test)["b"], gm2(inp_test)["b"]))

    def test_retracibility_nested_list_out(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(20, 98)
                self.bias = torch.randn(20)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.weight, self.bias)

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(16, 33, 3)
                self.linear = MyLinear()

            def forward(self, x):
                a1, a2 = x["a"]
                b = x["b"]
                a1_conv = self.conv(a1)
                a1_linear = self.linear(a1_conv)
                a2_conv = self.conv(a2)
                a2_linear = self.linear(a2_conv)
                b_conv = self.conv(b)
                b_linear = self.linear(b_conv)
                return [
                    [
                        a1_linear.cos() + b_linear.sin(),
                        a1_linear.cos() + b_linear.sin(),
                    ],
                    [
                        a2_linear.sin() + b_linear.cos(),
                        a2_linear.sin() + b_linear.cos(),
                    ],
                ]

        inp_container = {
            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
            "b": torch.randn(20, 16, 50, 100),
        }

        gm, _ = torch._dynamo.export(Foo(), inp_container, aten_graph=True)
        gm2, _ = torch._dynamo.export(gm, inp_container, aten_graph=True)

        inp_test = {
            "a": (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100)),
            "b": torch.randn(20, 16, 50, 100),
        }

        self.assertTrue(torch.allclose(gm(inp_test)[0][0], gm2(inp_test)[0][0]))
        self.assertTrue(torch.allclose(gm(inp_test)[0][1], gm2(inp_test)[0][1]))
        self.assertTrue(torch.allclose(gm(inp_test)[1][0], gm2(inp_test)[1][0]))
        self.assertTrue(torch.allclose(gm(inp_test)[1][1], gm2(inp_test)[1][1]))

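    # torch.fx._pytree.tree_flatten_spec flattens a container against an
    # already-recorded TreeSpec instead of recomputing one, so its output
    # should match torch.utils._pytree.tree_flatten on the same input.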
    def test_fx_pytree(self):
        def foo(args):
            flat_args, spec = torch.utils._pytree.tree_flatten(args)
            flat_args_fx = torch.fx._pytree.tree_flatten_spec(args, spec)
            return flat_args_fx[0] + flat_args[0]

        inp_container = (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50, 100))

        gm, _ = torch._dynamo.export(foo, inp_container, aten_graph=True)

        self.assertTrue(torch.allclose(foo(inp_container), gm(inp_container)))

    @config.patch(suppress_errors=True)
    @config.patch(verbose=True)
    def test_export_with_map_zero_sized_tensor_suppress_errors(self):
        from functorch.experimental.control_flow import map

        class Module(torch.nn.Module):
            def forward(self, xs):
                def body(x):
                    return x + 1

                return map(body, xs)

        mod = Module()
        xs = torch.randn(0, 2)
        with self.assertRaises(
            torch._dynamo.exc.Unsupported,
        ):
            torch._dynamo.export(mod, xs)

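    # Export must not mutate the real module state while tracing: even though
    # forward calls buffer.add_() in the two tests below, the eager buffers
    # should still hold their initial zeros after torch._dynamo.export returns.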
    def test_param_buffer_safe_from_mutation_simple(self):
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5, 5))

            def forward(self, x):
                self.buffer1.add_(1)
                return x + self.buffer1

        gm, _ = torch._dynamo.export(Module(), torch.ones(5, 5), aten_graph=False)
        buffers = list(gm.named_buffers())
        self.assertEqual(len(buffers), 1)

        name, buffer = buffers[0]
        self.assertEqual(name, "L__self___buffer1")

        self.assertTrue(torch.allclose(buffer, torch.zeros(5)))

    def test_param_buffer_safe_from_mutation_recurse(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer2 = torch.nn.Buffer(torch.zeros(5))

            def forward(self, x):
                return x.sum() + self.buffer2.sum()

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer1 = torch.nn.Buffer(torch.zeros(5))
                self.child = Child()

            def forward(self, x):
                self.buffer1.add_(1)
                self.child.buffer2.add_(2)
                return x.sum() + self.buffer1.sum() + self.child(x)

        gm, _ = torch._dynamo.export(Module(), torch.ones(5), aten_graph=False)
        for _, buffer in gm.named_buffers():
            self.assertTrue(torch.allclose(buffer, torch.zeros(5)))

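    # pre_dispatch=True captures an ATen graph at the pre-dispatch level,
    # before autograd and functionalization transforms run. Combined with
    # cond, the shape predicate (e.g. x.shape[0] > 4) stays in the graph, so
    # the exported module can take either branch depending on the runtime
    # input shape, as the checks with different-sized inputs below verify.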
    def test_predispatch_with_higher_order(self):
        def f(x):
            return cond(x.shape[0] > 4, lambda x: x + 5, lambda x: x - 3, [x])

        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
            torch.randn(4, 4)
        )
        inp1 = torch.randn(4, 4)
        inp2 = torch.randn(6, 4)
        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))

    def test_predispatch_with_higher_order_nested(self):
        def f(x):
            def true_fn(x):
                return cond(x.shape[0] > 6, lambda x: x + 10, lambda x: x - 10, [x])

            return cond(x.shape[0] > 4, true_fn, lambda x: x - 3, [x])

        gm, _ = torch._dynamo.export(f, aten_graph=True, pre_dispatch=True)(
            torch.randn(4, 4)
        )
        inp1 = torch.randn(4, 4)
        inp2 = torch.randn(6, 4)
        inp3 = torch.randn(8, 4)
        self.assertTrue(torch.allclose(f(inp1), gm(inp1)))
        self.assertTrue(torch.allclose(f(inp2), gm(inp2)))
        self.assertTrue(torch.allclose(f(inp3), gm(inp3)))

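    # out_dtype(op, output_dtype, *args) is a higher-order op that, roughly,
    # computes op(*args) as if the arguments were promoted to output_dtype and
    # returns a tensor of that dtype; in the two tests below, int8 operands
    # accumulate into an int32 result, as the expected
    # torch.ops.higher_order.out_dtype graphs show.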
    def test_predispatch_with_for_out_dtype(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(torch.ops.aten.mm.default, torch.int32, x, self.weight)

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)

        self.assertTrue(torch.allclose(m(x), gm(x)))

    def test_predispatch_with_for_out_dtype_nested(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def true_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                ).sum()

            def false_fn(self, x):
                return out_dtype(
                    torch.ops.aten.mul.Tensor, torch.int32, x, self.weight
                ).sum()

            def forward(self, x):
                return cond(x.sum() != 0, self.true_fn, self.false_fn, [x])

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.ones((5, 5), dtype=torch.int8)
        gm, _ = torch._dynamo.export(m, x, aten_graph=True, pre_dispatch=True)

        self.assertTrue(torch.allclose(m(x), gm(x)))
        y = torch.zeros((5, 5), dtype=torch.int8)
        self.assertTrue(torch.allclose(m(y), gm(y)))

        self.assertExpectedInline(
            gm.true_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mm.default, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None
    return (sum_1,)""",
        )

        self.assertExpectedInline(
            gm.false_graph_0.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    out_dtype = torch.ops.higher_order.out_dtype(torch.ops.aten.mul.Tensor, torch.int32, arg1_1, arg0_1); arg1_1 = arg0_1 = None
    sum_1 = torch.ops.aten.sum.default(out_dtype); out_dtype = None
    return (sum_1,)""",
        )

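    # nn_module_stack metadata should be attached to every call_function node
    # produced by export, even when a submodule's forward is monkey-patched
    # onto the instance via forward.__get__(obj, cls).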
    def test_export_nn_module_stack_patched_module(self):
        def forward(self, x, y):
            return x * y

        class Toplevel(torch.nn.Module):
            def __init__(self, m):
                super().__init__()
                self.m = m

            def forward(self, x, y):
                return self.m(x, y)

        class M(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        t = Toplevel(M())
        t.m.forward = forward.__get__(t.m, M)
        x, y = torch.rand(3), torch.rand(3)
        gm, _ = torch._dynamo.export(t, x, y)

        self.assertTrue(torch.allclose(forward(None, x, y), gm(x, y)))
        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertIn("nn_module_stack", node.meta)

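    # The next few tests check that re-exporting a GraphModule preserves
    # per-node FX metadata (nn_module_stack, stack_trace) for untouched nodes,
    # while nodes whose meta was cleared pick up fresh metadata from the new
    # trace; the family covers re-export, recompiles, inlining, and graph
    # breaks.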
    def test_preserve_fx_node_metadata(self):
        class Module1(torch.nn.Module):
            def forward(self, x):
                return torch.sin(x)

        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = Module1()

            def forward(self, x):
                x = torch.cos(x)
                x = self.mod1(x)
                x = torch.relu(x)
                return x

        def fn(x):
            return torch.abs(x)

        mod = Module2()
        inp = torch.randn(3, 3)

        gm, _ = torch._dynamo.export(mod)(inp)

        # replace relu with fn
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.relu:
                nd.target = fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        gm2, _ = torch._dynamo.export(gm_edit)(inp)

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_); l_x_ = None
    x_1 = torch.sin(x); x = None
    x_2 = torch.relu(x_1); x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        def _contains_op(gm, target):
            for nd in gm.graph.nodes:
                if nd.target == target:
                    return True
            return False

        self.assertTrue(_contains_op(gm_edit, torch.cos))
        self.assertTrue(_contains_op(gm_edit, torch.sin))
        self.assertTrue(not _contains_op(gm_edit, torch.relu))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_); l_x_ = None
    x_1 = torch.sin(x); x = None
    x_2 = torch.abs(x_1); x_1 = None
    return pytree.tree_unflatten([x_2], self._out_spec)""",
        )

        # check for other metadata
        for op in (torch.sin, torch.cos):
            nd1 = next(filter(lambda nd: nd.target == op, gm.graph.nodes))
            nd2 = next(filter(lambda nd: nd.target == op, gm2.graph.nodes))
            self.assertTrue(
                ("nn_module_stack" in nd1.meta) == ("nn_module_stack" in nd2.meta)
            )
            if "nn_module_stack" in nd1.meta:
                self.assertEqual(
                    nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
                )
                self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])

    def test_preserve_fx_node_metadata_recompile(self):
        def fn(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))
        do_export = torch._dynamo.export(gm)
        torch.compile(fn, backend="eager")(torch.randn(3, 3))
        gm1, _ = do_export(torch.randn(3, 3))
        gm2, _ = do_export(torch.randn(5, 3))

        self.assertExpectedInline(
            gm1.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_); l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )
        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    sin = torch.sin(l_x_); l_x_ = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_inline(self):
        def f1(x):
            return torch.sin(x)

        gm, _ = torch._dynamo.export(f1)(torch.randn(3, 3))

        def f2(x):
            x = torch.cos(x)
            return gm(x)

        gm2, _ = torch._dynamo.export(f2)(torch.randn(3, 3))

        self.assertExpectedInline(
            gm2.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    x = torch.cos(l_x_); l_x_ = None
    sin = torch.sin(x); x = None
    return pytree.tree_unflatten([sin], self._out_spec)""",
        )

    def test_preserve_fx_node_metadata_graph_break(self):
        def fn(x):
            x = torch.sin(x)
            x = torch.abs(x)
            return torch.cos(x)

        def bad_fn(x):
            torch._dynamo.graph_break()
            return x

        gm, _ = torch._dynamo.export(fn)(torch.randn(3, 3))

        # replace abs with graph break
        gm_edit = copy.deepcopy(gm)
        for nd in gm_edit.graph.nodes:
            if nd.target == torch.abs:
                nd.target = bad_fn
                nd.meta.clear()
                break
        gm_edit.recompile()

        expected = [
            """x = torch.sin(l_x_)""",
            """cos = torch.cos(l_stack0_)""",
        ]

        def test_backend(gm: torch.fx.GraphModule, example_inputs):
            self.assertTrue(expected)
            # Normalize output for dynamic and not
            for nd in gm.graph.nodes:
                if "example_value" in nd.meta:
                    del nd.meta["example_value"]
            self.assertIn(expected[0], gm.print_readable(print_output=False))
            expected.pop(0)
            return gm.forward

        torch._dynamo.reset()
        opt_gm_edit = torch.compile(gm_edit, backend=test_backend)
        opt_gm_edit(torch.randn(3, 3))

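    # torch.inference_mode is traced into the exported graph as an explicit
    # _enter_inference_mode / _exit_inference_mode pair, so outputs keep
    # inference-tensor semantics (requires_grad=False) whenever the mode is
    # enabled, regardless of the inputs' requires_grad.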
    def test_torch_inference_mode_ctx(self):
        @torch.inference_mode()
        def fn(x):
            return x + 1

        gm, _ = torch._dynamo.export(fn, torch.rand(2, 2))

        inp = torch.randn(2, 2)
        out = gm(inp)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_args_0_ + 1; l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        self.assertEqual(out.requires_grad, False)
        with self.assertRaisesRegex(
            RuntimeError,
            "Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.",
        ):
            out.requires_grad = True

        @torch.inference_mode(False)
        def fn_no_inference(x):
            return x + 1

        gm_no_inference, _ = torch._dynamo.export(fn_no_inference, torch.rand(2, 2))
        self.assertExpectedInline(
            gm_no_inference.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_args_0_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
    add = l_args_0_ + 1; l_args_0_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )

        inp = torch.randn(2, 2)
        out = gm_no_inference(inp)
        self.assertEqual(out.requires_grad, False)
        out.requires_grad = True

        def fn(x):
            with torch.inference_mode():
                return x + 1

        gm, _ = torch._dynamo.export(fn)(torch.rand(2, 2))
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x):
    arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    l_x_ = arg0
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    add = l_x_ + 1; l_x_ = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([add], self._out_spec)""",  # NOQA: B950
        )
        inp = torch.randn(2, 2, requires_grad=True)
        out = gm(inp)
        self.assertEqual(out.requires_grad, False)

    def test_export_masking_with_no_grad(self):
        def fn(x, b, y):
            x = x.clone()
            x[b] = y
            return x

        def fn_no_grad(x, b, y):
            with torch.no_grad():
                return fn(x, b, y)

        def fn_inference_mode(x, b, y):
            with torch.inference_mode():
                return fn(x, b, y)

        x = torch.randn(4, requires_grad=True)
        b = torch.tensor([True, False, True, False])
        y = torch.randn(2, requires_grad=True)

        gm, _ = torch._dynamo.export(fn_no_grad)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None
    x = l_x_.clone(); l_x_ = None
    x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None
    _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None
    return pytree.tree_unflatten([x], self._out_spec)""",
        )

        gm, _ = torch._dynamo.export(fn_inference_mode)(x, b, y)
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, x, b, y):
    arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
    l_x_ = arg0
    l_b_ = arg1
    l_y_ = arg2
    _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
    x = l_x_.clone(); l_x_ = None
    x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None
    _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None
    return pytree.tree_unflatten([x], self._out_spec)""",  # NOQA: B950
        )

        gm, _ = torch._dynamo.export(fn)(x, b, y)

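    # Pure-Python operations on constant inputs (list.index on a constant
    # list, enum members used as tensor indices) are evaluated at trace time
    # and baked into the exported graph as constants, as the two tests below
    # demonstrate.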
    def test_dynamo_list_index(self):
        def fn(x, in_list):
            return x + in_list.index(2)

        inputs = (torch.ones(2, 2), [1, 2])
        graph, _ = torch._dynamo.export(fn)(*inputs)
        out = graph(*inputs)
        self.assertEqual(out, torch.ones(2, 2) + 1)

    def test_dynamo_enum_in_tuple(self):
        class IntEnum(int, Enum):
            X = 0

        def fn(tensor):
            return tensor[..., IntEnum.X]

        tensor = torch.rand((5, 5))
        graph, _ = torch._dynamo.export(fn)(tensor)
        out = graph(tensor)
        self.assertEqual(out, tensor[:, 0])

    def test_subclass_parameters(self):
        from torch.testing._internal.two_tensor import TwoTensor

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.p1 = torch.nn.Parameter(torch.ones(3, 4))
                self.p2 = torch.nn.Parameter(
                    TwoTensor(torch.zeros(3, 4), torch.zeros(3, 4))
                )

            def forward(self, x):
                return x + 2 * self.p1 + self.p2

        m = M()
        ref_x = torch.randn(3, 4)
        ref_out = m(ref_x)

        from torch._functorch._aot_autograd.subclass_parametrization import (
            unwrap_tensor_subclass_parameters,
        )

        unwrap_tensor_subclass_parameters(m)
        ref_x2 = ref_x.detach().clone()
        ref_out2 = m(ref_x2)
        self.assertEqual(ref_out2, ref_out)

        x = ref_x.detach().clone()
        graph, _ = torch._dynamo.export(m)(x)
        out = graph(x)
        self.assertEqual(ref_out, out)


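# Device-parameterized tests: instantiate_device_type_tests (at the bottom of
# this file) generates one copy of each test per entry in `devices`, passing
# the device string as the `device` argument.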
class ExportTestsDevice(torch._dynamo.test_case.TestCase):
    def test_export_with_parameters(self, device):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.features = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
                    ),
                    torch.nn.ReLU(inplace=True),
                )

            def forward(self, x):
                return self.features(x)

        model = MyModule().eval().to(device)
        random_inputs = (torch.rand([32, 3, 32, 32]).to(device),)
        dim_x = torch.export.Dim("dim_x", min=1, max=32)
        exp_program = torch.export.export(
            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}, strict=True
        )
        output_buffer = io.BytesIO()
        # Tests if we can restore saved nn.Parameters when we load them again
        torch.export.save(exp_program, output_buffer)
        loaded_model = torch.export.load(output_buffer)
        self.assertTrue(
            isinstance(
                loaded_model.module().get_parameter("features.0.weight"),
                torch.nn.Parameter,
            )
        )

    def test_export_fast_binary_broadcast_check(self, device):
        # This test looks at the case where we erroneously create a guard
        # when checking the equality of the operands' shape and the output
        # shape during FakeTensor's binary op fast path.

        class MyModel(torch.nn.Module):
            def forward(self, a, b):
                # final shape is (dim0, 4, 8)
                # order matters since a & the output have the same shape
                return b + a

        a = torch.randn(100, 4, 8)
        b = torch.randn(4, 8)
        model = MyModel().eval().to(device)
        batchsize = torch.export.Dim("dim0", min=3, max=1024)
        dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}

        torch.export.export(
            model, (a, b), dynamic_shapes=dynamic_shape_spec, strict=True
        )

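    # Same fast path, but with an unbacked symbol: u0 comes from .item(), and
    # torch._check_is_size(u0) lets the compiler assume it is a valid
    # (non-negative) size when building torch.ones(u0 + 1).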
    def test_export_fast_binary_broadcast_check_unbacked(self, device):
        class MyModel(torch.nn.Module):
            def forward(self, numel, scalar):
                u0 = numel.item()
                torch._check_is_size(u0)
                x = torch.ones(u0 + 1)
                return scalar - x

        model = MyModel().eval().to(device)
        numel = torch.tensor(10)
        scalar = torch.randn(1)
        torch.export.export(model, (numel, scalar), strict=True)


common_utils.instantiate_parametrized_tests(ExportTests)
devices = ["cuda", "hpu"]
instantiate_device_type_tests(ExportTestsDevice, globals(), only_for=devices)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()