mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 1e738420296a84406cd0a1626074ea6447a6603a. Reverted https://github.com/pytorch/pytorch/pull/137726 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but it looks like some internal components are failing after this change and need to be updated ([comment](https://github.com/pytorch/pytorch/pull/137726#issuecomment-2455332612))
4662 lines
166 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
import builtins
|
|
import contextlib
|
|
import copy
|
|
import functools
|
|
import inspect
|
|
import math
|
|
import numbers
|
|
import io
|
|
import operator
|
|
import os
|
|
import pickle
|
|
import sys
|
|
import torch
|
|
import traceback
|
|
import typing
|
|
import types
|
|
import warnings
|
|
import unittest
|
|
from math import sqrt
|
|
from functorch.experimental import control_flow
|
|
from torch.multiprocessing import Process
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
|
|
import torch.utils._pytree as pytree
|
|
import torch.fx._pytree as fx_pytree
|
|
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH, CodeGen
|
|
from torch.fx.node import Target, Argument, _format_arg
|
|
from torch.fx.passes import shape_prop
|
|
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
|
from torch.fx.experimental.rewriter import RewritingTracer
|
|
from torch.fx.operator_schemas import get_signature_for_torch_op
|
|
from copy import deepcopy
|
|
from collections import namedtuple
|
|
|
|
from torch.fx.proxy import TraceError
|
|
from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMPATIBILITY
|
|
from torch.fx._symbolic_trace import PHBase, PHWithMeta
|
|
from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401
|
|
from fx.test_dce_pass import TestDCE # noqa: F401
|
|
from fx.test_fx_const_fold import TestConstFold # noqa: F401
|
|
from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow # noqa: F401
|
|
from fx.test_pass_infra import TestPassManager # noqa: F401
|
|
from fx.test_common_passes import TestCommonPass # noqa: F401
|
|
from fx.test_cse_pass import TestCSEPass # noqa: F401
|
|
from fx.test_matcher_utils import TestMatcher # noqa: F401
|
|
from fx.test_source_matcher_utils import TestSourceMatcher # noqa: F401
|
|
|
|
from fx.test_gradual_type import AnnotationsTest # noqa: F401
|
|
from fx.test_gradual_type import TypeCheckerTest # noqa: F401
|
|
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Set, Tuple, Union
|
|
from torch.testing._internal.common_utils import (
|
|
IS_FBCODE,
|
|
IS_MACOS,
|
|
IS_WINDOWS,
|
|
find_library_location,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
from fx.named_tup import MyNamedTup
|
|
|
|
try:
|
|
from torchvision import models as torchvision_models
|
|
HAS_TORCHVISION = True
|
|
except ImportError:
|
|
HAS_TORCHVISION = False
|
|
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
|
|
|
|
class SimpleTest(torch.nn.Module):
    """Tiny module computing ``relu(x + 3.0)``."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
|
|
|
|
def a_non_torch_leaf(a, b):
    """Return ``a + b``; a plain module-level Python function with no torch calls."""
    result = a + b
    return result
|
|
|
|
# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    """Return ``int(x)``; module-level helper for the autowrap tests."""
    return int(x)


def fx_int_x2(x: float) -> int:
    """Return ``int(x) * 2``; presumably also used by the autowrap tests -- confirm."""
    return int(x) * 2
|
|
|
|
# Used in test_pytree. It is defined at module level because pickling a
# GraphModule that uses Point fails if Point is local to the test function
# (pickle looks classes up by their qualified name).
Point = namedtuple('Point', ['x', 'y'])
|
|
|
|
# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    """Return ``a[0] + a[1] + b``; registered below with ``wrap`` by *name*."""
    return a[0] + a[1] + b


# Registration by string name. test_wrap expects traced code to contain a
# call to this function instead of its inlined body.
wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')
|
|
|
|
def a_lifted_leaf2(a, b):
    """Same computation as ``a_lifted_leaf``; registered by passing the function object."""
    return a[0] + a[1] + b


# Registration by function object rather than by name.
wrap(a_lifted_leaf2)

# Builtins registered by name. Presumably this makes traces in this module
# record len()/getattr() calls as graph nodes -- confirm in the tests using it.
wrap('len')

wrap('getattr')
|
|
|
|
def wrapped_named_tup(p1, *, p2):
    """Return ``p1.x + p2.y``.

    Takes a keyword-only argument so wrapped calls with kwargs are exercised.
    Inputs are presumably Point-like namedtuples -- confirm against callers.
    """
    return p1.x + p2.y


wrap(wrapped_named_tup)
|
|
|
|
# ``wrap`` used as a decorator.
@wrap
def wrapped_via_decorator(a):
    """Return ``a + 1``."""
    return a + 1
|
|
|
|
# Registered by name *before* the function below is defined. This presumably
# works because the name is only resolved at trace time -- confirm.
wrap('wrapped_with_submodule')


def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    """Apply the passed-in ``batchnorm1d`` submodule to ``x``."""
    return batchnorm1d(x)
|
|
|
|
def my_decorator(f):
    """Return a pass-through wrapper around ``f`` with ``f``'s metadata copied over.

    The wrapper forwards all positional and keyword arguments to ``f`` unchanged.
    """
    def wrapper_inside_decorator(*args, **kwargs):
        return f(*args, **kwargs)

    return functools.wraps(f)(wrapper_inside_decorator)
|
|
|
|
# ``wrap`` applied on top of a functools.wraps-based decorator.
@wrap
@my_decorator
def wrapped_decorated_fn(x):
    """Identity function."""
    return x
|
|
|
|
# Keep references to the original function objects. Tests assert that after
# tracing, the module-level names still refer to these same objects.
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
# Module-level alias of math.sqrt (its users are outside this excerpt).
_sqrt = sqrt

wrap('wrapper_fn')
|
|
|
|
def wrapper_fn(x):
    """Call ``torch.foo(x)``.

    NOTE(review): ``torch.foo`` does not look like a real torch function, so
    an eager call presumably raises AttributeError. The function is registered
    with ``wrap`` above and presumably only ever traced -- confirm against the
    tests that use it.
    """
    return torch.foo(x)
|
|
|
|
class Pair(NamedTuple):
    """Two-tensor NamedTuple that defines a custom FX repr hook."""
    x : torch.Tensor
    y : torch.Tensor

    def _custom_fx_repr_fn(self) -> str:
        # Formats each field with torch.fx.node._format_arg. Presumably FX
        # calls this hook when printing a Pair used as a node argument -- confirm.
        return f"Pair(x={_format_arg(self.x)}, y={_format_arg(self.y)})"
|
|
|
|
# for testing pytrees
class Foo:  # noqa: B209
    """Plain two-attribute container used by the pytree tests."""

    def __init__(self, a, b):
        self.a, self.b = a, b
|
|
|
|
class Add(torch.nn.Module):
    """Module whose forward adds its input to itself."""

    def forward(self, x):
        doubled = x + x
        return doubled
|
|
|
|
@torch.fx.has_side_effect
@torch.fx.wrap
def side_effect_func(x: torch.Tensor):
    """Print ``x``; a wrapped leaf explicitly marked as having a side effect.

    Presumably the ``has_side_effect`` mark keeps the call's node from being
    removed by dead-code elimination even though its result is unused --
    confirm against the DCE tests.
    """
    print(x)
|
|
|
|
class TestFX(JitTestCase):
|
|
def setUp(self):
    """Enable FX mutable-operation checking and load the TorchBind test library."""
    super().setUp()
    # Checking for mutable operations while tracing is feature flagged.
    # Enable it in testing but not by default.
    self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations

    # Load the library *before* flipping the global flag. If loading raises,
    # unittest does not run tearDown(), so flipping first would leak
    # check_mutable_operations=True into every later test. The library
    # presumably provides the TorchBind classes used by test_native_callable.
    if not (IS_FBCODE or IS_WINDOWS or IS_MACOS):
        lib_file_path = find_library_location('libtorchbind_test.so')
        torch.ops.load_library(str(lib_file_path))

    torch.fx.proxy.TracerBase.check_mutable_operations = True
|
|
|
|
def tearDown(self):
    """Restore the global tracer flag that setUp() changed."""
    try:
        super().tearDown()
    finally:
        # Restore even if the base-class teardown raises, so the flag never
        # leaks into other tests.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
|
|
|
|
def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
    """Assert that ``m`` and its symbolically traced GraphModule agree.

    Runs ``m`` eagerly on ``args``/``kwargs``, traces it with
    ``symbolic_trace``, lints the graph, runs the GraphModule on the same
    inputs, and compares the two results with ``assertEqual``.
    """
    call_kwargs = kwargs or {}
    expected = m(*args, **call_kwargs)

    traced = symbolic_trace(m)
    traced.graph.lint()

    actual = traced(*args, **call_kwargs)
    self.assertEqual(expected, actual)
|
|
|
|
def test_graph_module(self):
    """Smoke-test symbolic tracing (and scripting) of assorted module shapes.

    Covers:
    - nested submodules and parameters;
    - a multi-output ``torch.max``;
    - a signature with defaults, ``*args``, a keyword-only argument and ``**kwargs``;
    - re-running ``__init__`` on a GraphModule created via ``__new__``.
    """
    class MySub(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.w = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            return self.w + x

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lin = torch.nn.Linear(4, 3)
            self.sub_mod = MySub()
            self.w = torch.nn.Parameter(torch.rand(3))

        def forward(self, A, B, c):
            t = torch.sigmoid(A) + self.lin(c)
            # Mixes attribute access, unary/binary operators and a method
            # call with a kwarg so all of them flow through tracing.
            return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))

    m = MyModule()
    gm = symbolic_trace(m)

    # Only checks that the traced module can be scripted without error.
    ms = torch.jit.script(gm)

    class M2(torch.nn.Module):
        def forward(self, A):
            m, idx = torch.max(A, 0)
            return m + 1, idx + 1

    m2 = M2()
    gm2 = symbolic_trace(m2)

    class T(torch.nn.Module):

        def forward(self, A, b=4, *args, c=5, **kwargs):
            x = A + 1 + args[0] + kwargs['3']
            return x

    t = T()
    symbolic_trace(t)

    # test for issue described at https://github.com/pytorch/pytorch/issues/63883
    class M3(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    m3 = M3()
    gm3 = symbolic_trace(m3)
    # Build a second instance of gm3's class by calling __init__ explicitly.
    new_instance = gm3.__new__(type(gm3))
    new_instance.__init__(gm3, gm3.graph)

    x = torch.randn(5, 3)
    torch.testing.assert_close(new_instance(x), torch.relu(x))
|
|
|
|
def test_informative_co_filename(self):
    """The generated forward()'s co_filename should mention this test file's name."""
    class MyModule(torch.nn.Module):
        def forward(self, a):
            return a * 2

    gm = symbolic_trace(MyModule())
    self.assertIn(os.path.basename(__file__), gm.forward.__code__.co_filename)
|
|
|
|
def test_custom_import(self):
    """A hand-built graph that calls a module-level non-torch function runs correctly."""
    graph = torch.fx.Graph()
    a = graph.placeholder('x')
    b = graph.placeholder('y')
    # a_non_torch_leaf is a plain function from this file. The test name
    # suggests it checks how generated code makes such callables available.
    c = graph.call_function(a_non_torch_leaf, (a, b))
    d = graph.call_function(torch.sin, (c,))
    graph.output(d)
    gm = GraphModule(torch.nn.Module(), graph)
    x, y = torch.rand(1), torch.rand(1)
    self.assertEqual(torch.sin(x + y), gm(x, y))
|
|
|
|
def test_args_kwargs(self):
    """Trace a forward() that takes only ``*args`` and ``**kwargs``."""
    class T(torch.nn.Module):
        def forward(self, *args, **kwargs):
            x = args[0] + kwargs['foo']
            return x

    t = T()
    self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
|
|
|
|
def test_varargs_concrete(self):
    """Trace a ``*args`` forward by supplying PH placeholders as concrete_args."""
    class T(torch.nn.Module):
        def forward(self, *args, **kwargs):
            x = args[0] + args[1]
            return x

    args = (torch.rand(1), torch.rand(1))

    t = T()
    ref_outs = t(*args)
    # Two PH entries, matching the two positional values forward() reads.
    gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
    gm.graph.lint()
    test_outs = gm(*args)
    self.assertEqual(ref_outs, test_outs)
|
|
|
|
def test_args_kwargs_no_self(self):
    """Tracing must reject a forward() whose ``self`` is folded into ``*args``."""
    class T(torch.nn.Module):
        def forward(*args, **kwargs):  # noqa: B902
            self = args[0]
            return torch.relu(args[1])

    t = T()
    with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
        self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
|
|
|
|
def test_fx_shifts(self):
    """``<<`` and ``>>`` on an integer tensor trace and match eager results."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x << 3, x >> 3

    input = torch.LongTensor(10).random_(0, 1024)

    m = MyModule()
    self.checkGraphModule(m, (input,))
|
|
|
|
def test_fx_and_or(self):
    """``&`` and ``|`` on an integer tensor trace and match eager results."""
    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x & x, x | x

    input = torch.LongTensor(10).random_(0, 1024)

    m = MyModule()
    self.checkGraphModule(m, (input,))
|
|
|
|
def test_dict(self):
    """A dict input can be indexed during tracing, and a dict can be returned."""
    class MyDictMod(torch.nn.Module):
        def forward(self, d):
            return d['3'].relu(), {'4' : d['3'].neg()}

    input_dict = {'3': torch.rand(3, 4)}
    m = MyDictMod()

    self.checkGraphModule(m, (input_dict,))
|
|
|
|
def test_matmul_tracing(self):
    """``@`` traces with the traced value on either side of a constant tensor."""
    const = torch.randn(3)

    def matmul_f(x):
        return x @ const

    mod = symbolic_trace(matmul_f)
    inp = torch.randn(3)
    self.assertEqual(mod(inp), matmul_f(inp))

    # Constant on the left; presumably exercises the reflected (__rmatmul__) path.
    def rmatmul_f(x):
        return const @ x

    mod = symbolic_trace(rmatmul_f)
    inp = torch.randn(3)
    self.assertEqual(mod(inp), rmatmul_f(inp))
|
|
|
|
@skipIfNoDynamoSupport
def test_control_flow_tracing(self):
    """control_flow.cond with a Proxy predicate is rejected by plain symbolic tracing."""
    def true(x, y):
        return x + y

    def false(x, y):
        return x - y

    def f(x, y):
        # Under tracing, x[0] == 0 is a Proxy, which cond refuses as a predicate.
        x = control_flow.cond(x[0] == 0, true, false, [x, y])

    with self.assertRaisesRegex(RuntimeError, r"Expected pred to be bool or tensor, but got Proxy\(eq\)"):
        _ = symbolic_trace(f)
|
|
|
|
def test_disallow_override(self):
    """A Tracer subclass can veto nodes by overriding create_node().

    NoMutableCallTracer raises for any target whose name ends in ``_`` (the
    in-place naming convention). The three modules cover:
    - an in-place Tensor method;
    - an in-place free function;
    - an in-place call on a concrete tensor that takes a traced value as argument.
    """
    # Custom delegate to disallow in-place tensor operations
    class NoMutableCallTracer(Tracer):
        def create_node(self, kind : str, target : Union[str, Callable],
                        args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                        type_expr : Optional[Any] = None) -> Node:
            # Use a separate local so the caller-requested node ``name`` is not
            # clobbered by the target's type name. Forward ``type_expr`` too.
            target_name = target if isinstance(target, str) else torch.typename(target)
            if target_name.endswith('_'):
                raise RuntimeError('In-place operations are not supported')
            return super().create_node(kind, target, args, kwargs, name, type_expr)

    # Test method
    class MyInplaceMod(torch.nn.Module):
        def forward(self, x):
            x.add_(3.0)
            return x

    m = MyInplaceMod()

    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m)

    # Test free function
    class MyInplaceMod2(torch.nn.Module):
        def forward(self, x):
            torch.log_(x)
            return x
    m2 = MyInplaceMod2()
    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m2)

    # Test symbolic node as an arg
    class MyInplaceMod3(torch.nn.Module):
        def forward(self, x):
            y = torch.ones(3, 4)
            y.add_(x)
            return x
    m3 = MyInplaceMod3()
    with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
        NoMutableCallTracer().trace(m3)
|
|
|
|
def test_leaf_module(self):
    """With is_leaf_module() always False, tracing emits no call_module nodes."""
    # Custom delegate to make it so that there are no leaf modules, everything
    # should get traced through
    class NoLeafModulesTracer(Tracer):
        def is_leaf_module(self, m, qualname):
            return False

    class MyReluMod(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    mrm = MyReluMod()
    sym = NoLeafModulesTracer().trace(mrm)
    for node in sym.nodes:
        self.assertNotEqual(node.op, 'call_module')
    sym.lint()
|
|
|
|
def test_wrap(self):
    """A function registered via wrap('name') appears as a call in traced code."""
    self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)

    m = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf', m.code)
    # (4 + 2 + 3) + (3 + 4 + 5) + (2 + 2 + 2) == 27
    self.assertEqual(27, m(2))
    # The module-level name must still be the original function after tracing.
    self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
|
|
|
|
def test_wrap_fn_directly(self):
    """A function registered via wrap(fn) appears as a call in traced code."""
    self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

    def to_trace(y):
        return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)

    m = symbolic_trace(to_trace)
    self.assertIn('a_lifted_leaf2', m.code)
    # (4 + 2 + 3) + (3 + 4 + 5) + (2 + 2 + 2) == 27
    self.assertEqual(27, m(2))
    # The module-level name must still be the original function after tracing.
    self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
|
|
|
|
def test_wrapped_via_decorator(self):
    """A function wrapped with @wrap is recorded as a call and left unpatched."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    # No patching marker attribute may be left on the function after tracing.
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
|
|
|
|
def test_wrapped_via_decorator_and_transformed(self):
    """As test_wrapped_via_decorator, then check that Transformer keeps the wrapped call."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    # Re-run the graph through a Transformer; the same invariants must hold.
    transformed = torch.fx.Transformer(m).transform()
    self.assertIn('wrapped_via_decorator', transformed.code)
    self.assertEqual(transformed(0), 1)
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
|
|
|
|
def test_wrap_with_submodule(self):
    """A wrapped function may receive a submodule as an argument."""

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

        def forward(self, x: torch.Tensor):
            return wrapped_with_submodule(x, self.batchnorm1d)

    m = symbolic_trace(M())

    self.assertIn("wrapped_with_submodule", m.code)

    input = torch.rand(3, 2)
    # With affine=False there are no learnable parameters, so a fresh
    # reference BatchNorm1d is expected to produce the same output.
    ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
    self.assertEqual(ref_batchnorm1d(input), m(input))
|
|
|
|
def test_wrapped_retrace(self):
    """Re-tracing a GraphModule keeps the wrapped function as a call."""
    def to_trace(y):
        return wrapped_via_decorator(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_via_decorator', m.code)
    self.assertEqual(m(0), 1)

    retraced = symbolic_trace(m)
    self.assertIn('wrapped_via_decorator', retraced.code)
    self.assertEqual(retraced(0), 1)
|
|
|
|
def test_wrap_decorated_function(self):
    """wrap() works on a function that was itself decorated with functools.wraps."""
    def to_trace(y):
        return wrapped_decorated_fn(y)

    m = symbolic_trace(to_trace)
    self.assertIn('wrapped_decorated_fn', m.code)
    self.assertEqual(m(1), 1)
|
|
|
|
def test_graph_edit_with_proxy(self):
    """Proxy arithmetic can append nodes to a copied graph."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b
    m = M()
    g = symbolic_trace(m).graph
    new_g = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    # output_val is the copy of g's returned value (a + b); the final
    # assertion relies on this.
    output_val = new_g.graph_copy(g, val_map)
    t = Proxy(output_val)
    # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    gm.graph.lint()
    # (3 + 4) + (3 + 4) == 14
    self.assertEqual(gm(3, 4), 14)
|
|
|
|
def test_proxy_deepcopy_without_tracer(self):
    """Deep-copying a tracer-less Proxy yields a Proxy around a distinct, equivalent node."""
    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return 2 * x

    module = MyModule()
    traced = symbolic_trace(module)
    # The node just before the output node, i.e. the multiplication.
    node = list(traced.graph.nodes)[-2]
    p = torch.fx.Proxy(node, None)
    node.proxy = p
    p2 = copy.deepcopy(p)
    self.assertTrue(isinstance(p2, torch.fx.Proxy))
    self.assertEqual(p2.node.name, node.name)
    self.assertEqual(p2.node.target, node.target)
    self.assertNotEqual(id(p2.node), id(node))
|
|
|
|
def test_proxy_deepcopy_with_tracer(self):
    """Deep-copying a Proxy also deep-copies its tracer, keeping the tracer's attributes."""
    class TestTracer(Tracer):
        def __init__(self, name):
            super().__init__()
            self.name = name

        def is_leaf_module(self, module, name):
            return True

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            return 2 * x

    module = MyModule()
    # The tracer is only attached to the Proxy below; the graph itself comes
    # from symbolic_trace.
    tracer = TestTracer("mytracer")
    traced = symbolic_trace(module)
    node = list(traced.graph.nodes)[-2]
    p = torch.fx.Proxy(node, tracer)
    node.proxy = p
    p2 = copy.deepcopy(p)
    self.assertTrue(isinstance(p2, torch.fx.Proxy))
    self.assertTrue(isinstance(p2.tracer, torch.fx._symbolic_trace.Tracer))
    self.assertEqual(p2.tracer.name, "mytracer")
    self.assertEqual(p2.node.name, node.name)
    self.assertEqual(p2.node.target, node.target)
    self.assertNotEqual(id(p2.node), id(node))
    self.assertNotEqual(id(p2.tracer), id(tracer))
|
|
|
|
def test_concrete_arg_none_assert(self):
    """Specializing an argument to None makes the traced module reject other values."""
    class Foo(torch.nn.Module):
        def forward(self, x, val=None):
            return x if val is None else x + val

    f = Foo()
    traced = torch.fx.symbolic_trace(f, concrete_args={'val' : None})
    # Passing a non-None value for the specialized argument must fail.
    with self.assertRaisesRegex(AssertionError, 'val has been specialized to have value None'):
        traced(torch.randn(5), torch.randn(5))

    x = torch.randn(5)
    torch.testing.assert_close(traced(x), f(x))
|
|
|
|
def test_trace_multiple_funcs(self):
    """Tracer.traced_func_name selects which method of the root is traced.

    A name that is not an attribute of the root must raise AssertionError.
    """
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            return x + y

        def minus_forward(self, x, y):
            return x - y

        def multiply_forward(self, x, y):
            return x * y

    f = Foo()
    x, y = torch.randn(5), torch.randn(5)

    # (A leftover debug print of torch.__version__ was removed here.)
    tracer = Tracer()
    torch.testing.assert_close(GraphModule(f, tracer.trace(f))(x, y), f(x, y))

    tracer.traced_func_name = "minus_forward"
    torch.testing.assert_close(
        GraphModule(f, tracer.trace(f))(x, y),
        f.minus_forward(x, y),
    )

    tracer.traced_func_name = "multiply_forward"
    torch.testing.assert_close(
        GraphModule(f, tracer.trace(f))(x, y),
        f.multiply_forward(x, y),
    )

    tracer.traced_func_name = "add_forward"
    with self.assertRaisesRegex(AssertionError, "doesn't exist in"):
        tracer.trace(f)
|
|
|
|
def test_graph_unique_names(self):
    """Nodes added to a copied graph through a Proxy still get unique names."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b
    m = M()
    g = symbolic_trace(m).graph
    new_g = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    output_val = new_g.graph_copy(g, val_map)
    t = Proxy(output_val)
    # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
    new_g.output((t + t).node)
    gm = GraphModule(m, new_g)
    seen_names : Set[str] = set()
    for node in gm.graph.nodes:
        # assertNotIn rather than a bare assert: bare asserts are stripped
        # under ``python -O``, and this also reports the duplicated name.
        self.assertNotIn(node.name, seen_names)
        seen_names.add(node.name)
|
|
|
|
def test_stack_traces(self):
    """With record_stack_traces, nodes (and their copies) carry a stack trace into this file."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    tracer = torch.fx.Tracer()
    tracer.record_stack_traces = True

    graph = tracer.trace(M())
    # saving the original list because we will insert new nodes as a part of a test
    orig_graph_nodes = list(graph.nodes)
    for node in orig_graph_nodes:
        if node.op == 'output':
            continue
        # unittest assertions instead of bare asserts, which vanish under -O.
        self.assertIsNotNone(node.stack_trace)
        self.assertIn('test_fx.py', node.stack_trace)

        # verify that copying the node does not lose the stack trace
        new_node = graph.node_copy(node)
        self.assertIsNotNone(new_node.stack_trace)
        self.assertIn('test_fx.py', new_node.stack_trace)
|
|
|
|
def test_stack_traces_with_transformer(self):
    """Nodes produced by Transformer keep the original nodes' stack traces."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    tracer = torch.fx.Tracer()
    tracer.record_stack_traces = True

    graph = tracer.trace(M())
    gm = GraphModule(tracer.root, graph)
    new_gm = Transformer(gm).transform()

    # nodes after Transformer should still preserve the original node's stack trace
    for node in new_gm.graph.nodes:
        if node.op in {'placeholder', 'output'}:
            continue
        # unittest assertions instead of a bare assert, which vanishes under -O.
        self.assertIsNotNone(node.stack_trace)
        self.assertIn('test_fx.py', node.stack_trace)
|
|
|
|
def test_lineno_map(self):
    """Check GraphModule._lineno_map, including after custom code generation."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            a = torch.sin(a)
            b = torch.cos(b)
            return a + b

    tracer = torch.fx.Tracer()
    graph = tracer.trace(M())
    gm = GraphModule(tracer.root, graph)
    # Subset check only: the map may contain additional entries.
    expected = {1: 2, 2: 3, 3: 4, 4: 5}
    self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))

    # test custom codegen
    def transform_code(code):
        # Prepend one extra line to the generated code.
        return ["print('hello!')\n", *code]
    gm.graph.on_generate_code(lambda _: transform_code)
    gm.recompile()
    # The prepended line shifts each key by one; the values stay the same.
    expected = {2: 2, 3: 3, 4: 4, 5: 5}
    self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items())))
|
|
|
|
def test_graph_unique_names_manual(self):
    """graph_copy keeps node names unique even when hand-picked names could collide."""
    graph : torch.fx.Graph = torch.fx.Graph()
    a : torch.fx.Node = graph.create_node('placeholder', 'x')
    # 'foo_1_1' and 'foo_1' are chosen to look like auto-generated suffixes.
    b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
    c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
    d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
    graph.output(d)
    graph2 = torch.fx.Graph()
    val_map : Dict[Node, Node] = {}
    graph2.graph_copy(graph, val_map)
    seen_names : Set[str] = set()
    for node in graph2.nodes:
        # assertNotIn rather than a bare assert (stripped under ``python -O``).
        self.assertNotIn(node.name, seen_names)
        seen_names.add(node.name)
|
|
|
|
def test_unpack(self):
    """Tuple-unpacking a (concrete) tuple input inside forward() traces correctly."""
    class M(torch.nn.Module):
        def forward(self, a, b):
            c, d = a
            return c + d + b

    a = (torch.rand(1), torch.rand(1))
    b = torch.rand(1)
    m = M()
    self.checkGraphModule(m, (a, b))
|
|
|
|
def test_native_callable(self):
    """Lower a traced module to a native (TorchBind) interpreter and compare results."""
    if IS_FBCODE or IS_WINDOWS or IS_MACOS:
        raise unittest.SkipTest("non-portable load_library call used in test")
    # This test exercises the case where we use FX to translate from Python
    # code to some native callable object
    #
    # For the purposes of testing, we use ElementwiseInterpreter defined
    # in test_custom_class.cpp.
    #
    # We test that we can
    # 1) Construct a native callable from FX IR
    # 2) Construct a drop-in replacement module that delegates to the
    #    native callable rather than the original code
    # 3) Run both the original code and native callable wrapper with
    #    equivalent results
    # 4) TorchScript compile the native callable wrapper and confirm
    #    equivalent results with the reference
    # 5) TorchScript serialize and deserialize the native callable
    #    and confirm equivalent results with the reference

    # We use this simple Module as a reference computation
    class MySimpleMod(torch.nn.Module):
        def forward(self, x):
            return 3.0 * x + x

    msm = MySimpleMod()

    # This is what a lowering pass might look like: a function that takes
    # a valid nn.Module, symbolically traces it, lowers the Module to some
    # representation, and wraps that representation up into another
    # nn.Module instance that handles dispatch to the compiled/lowered code.
    def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
        # ===== Stage 1: Symbolic trace the module =====
        mod = symbolic_trace(orig_mod)

        # ===== Stage 2: Lower GraphModule representation to the C++
        # interpreter's instruction format ======
        instructions = []
        constant_idx = 0
        constants = {}
        fn_input_names = []

        target_to_name = {
            operator.add : "add",
            operator.mul : "mul"
        }

        output_node : Optional[Node] = None
        # For each instruction, create a triple
        # (instruction_name : str, inputs : List[str], output : str)
        # to feed into the C++ interpreter
        for n in mod.graph.nodes:
            target, args, out_name = n.target, n.args, n.name
            assert len(n.kwargs) == 0, "kwargs currently not supported"

            if n.op == 'placeholder':
                # Placeholders specify function argument names. Save these
                # for later when we generate the wrapper GraphModule
                fn_input_names.append(target)
            elif n.op == 'call_function':
                # ``target`` is a callable here, so use an f-string: str + callable
                # would raise TypeError and hide the intended message.
                assert target in target_to_name, f"Unsupported call target {target}"
                arg_names = []
                for arg in args:
                    if not isinstance(arg, Node):
                        # Pull out constants. These constants will later be
                        # fed to the interpreter C++ object via add_constant()
                        arg_name = f'constant_{constant_idx}'
                        constants[arg_name] = torch.tensor(
                            [arg] if isinstance(arg, numbers.Number) else arg)
                        arg_names.append(arg_name)
                        constant_idx += 1
                    else:
                        arg_names.append(arg.name)
                instructions.append((target_to_name[target], arg_names, out_name))
            elif n.op == 'output':
                if output_node is not None:
                    raise RuntimeError('Multiple output nodes!')
                output_node = n
            else:
                raise RuntimeError('Unsupported opcode ' + n.op)

        interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
        # Load constants
        for k, v in constants.items():
            interpreter.add_constant(k, v)
        # Specify names for positional input arguments
        interpreter.set_input_names(fn_input_names)
        # Load instructions
        interpreter.set_instructions(instructions)
        # Specify name for single output. Fail clearly (not with an
        # AttributeError on None) if the graph had no output node.
        assert output_node is not None, "graph has no output node"
        assert isinstance(output_node.args[0], torch.fx.Node)
        interpreter.set_output_name(output_node.args[0].name)

        # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
        class WrapperModule(torch.nn.Module):
            def __init__(self, interpreter):
                super().__init__()
                self.interpreter = interpreter

        wrapper = WrapperModule(interpreter)

        # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
        # 3) Returns the specified return value

        # FIXME: The following code could be greatly simplified by symbolic_trace'ing
        # the wrapper with a Tracer that considers the Wrapper instance a root
        # module, however, I can't get `__call__` exposed on TorchBind classes
        # without it messing up Python `hasattr` for some reason. More digging
        # into CPython's implementation of hasattr is probably in order...

        graph = torch.fx.Graph()
        # Add placeholders for fn inputs
        placeholder_nodes = []
        for name in fn_input_names:
            placeholder_nodes.append(graph.create_node('placeholder', name))

        # Get the interpreter object
        interpreter_node = graph.create_node('get_attr', 'interpreter')

        # Add a node to call the interpreter instance
        output_node = graph.create_node(
            op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))

        # Register output
        graph.output(output_node)

        graph.lint()

        # Return final GraphModule!!!
        return GraphModule(wrapper, graph)

    # Lower GraphModule to C++ interpreter
    lowered = lower_to_elementwise_interpreter(msm)

    # Compare correctness with original module
    x = torch.rand(3, 4)
    ref_out = msm(x)
    test_out = lowered(x)
    torch.testing.assert_close(test_out, ref_out)

    # Test TorchScript compilation
    scripted_lowered = torch.jit.script(lowered)
    script_out = scripted_lowered(x)
    torch.testing.assert_close(script_out, ref_out)

    # Test TorchScript ser/de
    import_copy = self.getExportImportCopy(scripted_lowered)
    imported_out = import_copy(x)
    torch.testing.assert_close(imported_out, ref_out)
|
|
|
|
def test_reserved_getattr(self):
    """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
    class M(torch.nn.Module):
        def forward(self, a):
            # Chained attribute access on a traced input; the check below is
            # that none of the resulting nodes is *named* "getattr".
            return a.foo.bar.baz

    m = M()
    m_g = symbolic_trace(m)
    m_g.graph.lint()
    for node in m_g.graph.nodes:
        self.assertTrue(node.name != "getattr")
|
|
|
|
@unittest.skip("Hotfix for SEV remediation")
def test_trace_buffer_slice(self):
    """Trace a model that slices a registered buffer by the input's batch size.

    Currently skipped (see the decorator).
    """
    bs, d_hid = 10, 23

    class ExampleCode(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
            self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
            self.lin = torch.nn.Linear(d_hid, d_hid)
            # Longer than any batch, so forward() must slice it.
            self.buffer = torch.nn.Buffer(torch.randn(bs + 100, d_hid))

        def forward(self, x):
            x = torch.mm(x, self.mm_param)
            skip_connection = x
            x = torch.relu(x)
            # Slice bound depends on x.shape[0], a traced value.
            x = torch.mm(x, self.mm_param) + self.buffer[:x.shape[0]]
            x = self.lin(x)
            x = torch.relu(x)
            x = x + skip_connection
            x = torch.mm(x, self.mm_param2)
            x = self.lin(x)
            return x

    ec = ExampleCode()

    traced = torch.fx.symbolic_trace(ec)

    x = torch.randn(bs, d_hid)
    torch.testing.assert_close(ec(x), traced(x))
|
|
|
|
def test_node_tagging(self):
    """A Tracer can post-process every node it creates (here: attach a tag)."""
    class TaggingTracer(Tracer):
        def create_node(self, kind : str, target : Union[str, Callable],
                        args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
                        type_expr : Optional[Any] = None) -> Node:
            # Forward every argument, including type_expr, so tagging does
            # not silently drop the node's type annotation.
            n = super().create_node(kind, target, args, kwargs, name, type_expr)
            n.tag = 'foo'
            return n

    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = TaggingTracer().trace(m)
    g.lint()
    for n in g.nodes:
        self.assertTrue(hasattr(n, 'tag'))
        self.assertEqual(n.tag, 'foo')
|
|
|
|
def test_tensor_attribute(self):
    """Plain tensor attributes (not Parameters) on the root or a submodule trace."""
    class TensorAttribute(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # A bare tensor attribute, not a Parameter or registered buffer.
            self.tensor = torch.rand(3, 4)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.tensor)

    ta = TensorAttribute()
    traced = symbolic_trace(ta)
    traced(torch.rand(4, 4))

    # Same attribute reached through a submodule (qualified name "ta.tensor").
    class WrapperForQualname(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.ta = TensorAttribute()

        def forward(self, x):
            return torch.nn.functional.linear(x, self.ta.tensor)

    wfq = WrapperForQualname()
    traced2 = symbolic_trace(wfq)
    traced2.graph.lint()
    traced2(torch.rand(4, 4))
|
|
|
|
def test_tensor_attribute_coalseced(self):
    """A tensor constant referenced twice is stored once; distinct tensors are stored separately."""

    def count_attrs(fx_module):
        # Count distinct get_attr targets in *fx_module* (the original code
        # ignored this parameter and read the enclosing ``traced`` instead).
        return len({node.target for node in fx_module.graph.nodes if node.op == 'get_attr'})

    val = torch.tensor(5)

    def f(x):
        return x + val + val
    traced = symbolic_trace(f)
    traced.graph.lint()
    # `val` is used twice but should be coalesced into one attribute.
    self.assertEqual(count_attrs(traced), 1)

    val2 = torch.tensor(5)

    def f(x):
        val = torch.tensor(5)
        return x + val + val2

    traced = symbolic_trace(f)
    traced.graph.lint()
    # Two distinct tensor objects (even with equal values) -> two attributes.
    self.assertEqual(count_attrs(traced), 2)
|
|
|
|
def test_symbolic_trace_sequential(self):
|
|
class Simple(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.neg(x)
|
|
|
|
seq = torch.nn.Sequential(
|
|
Simple(),
|
|
Simple(),
|
|
Simple()
|
|
)
|
|
traced = symbolic_trace(seq)
|
|
traced.graph.lint()
|
|
x = torch.rand(3, 4)
|
|
self.assertEqual(traced(x), seq(x))
|
|
|
|
def test_tensor_constant(self):
|
|
class ConstTensor(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.nn.functional.linear(x, torch.zeros(3, 4))
|
|
|
|
ct = ConstTensor()
|
|
traced = symbolic_trace(ct)
|
|
traced.graph.lint()
|
|
traced(torch.rand(4, 4))
|
|
|
|
def test_pickle_graphmodule(self):
|
|
class Nested(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.st = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return self.st(x)
|
|
|
|
n = Nested()
|
|
traced = symbolic_trace(n)
|
|
traced.graph.lint()
|
|
pickled = pickle.dumps(traced)
|
|
loaded = pickle.loads(pickled)
|
|
loaded.graph.lint()
|
|
x = torch.rand(3, 4)
|
|
self.assertEqual(loaded(x), traced(x))
|
|
|
|
def test_pickle_custom_import(self):
    """A hand-built graph calling a non-torch function pickles and unpickles correctly."""
    graph = torch.fx.Graph()
    x_node, y_node = graph.placeholder('x'), graph.placeholder('y')
    leaf_out = graph.call_function(a_non_torch_leaf, (x_node, y_node))
    sin_out = graph.call_function(torch.sin, (leaf_out,))
    graph.output(sin_out)

    gm = GraphModule(torch.nn.Module(), graph)
    loaded = pickle.loads(pickle.dumps(gm))
    loaded.graph.lint()
    x, y = torch.rand(1), torch.rand(1)
    self.assertEqual(loaded(x, y), gm(x, y))
|
|
|
|
def test_all_input_nodes(self):
|
|
graph : torch.fx.Graph = torch.fx.Graph()
|
|
a : torch.fx.Node = graph.placeholder('x')
|
|
b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
|
|
c : torch.fx.Node = graph.get_attr('y_attr')
|
|
d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
|
|
e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
|
|
graph.output(e)
|
|
graph.lint()
|
|
|
|
self.assertEqual(b.all_input_nodes, [a])
|
|
self.assertEqual(c.all_input_nodes, [])
|
|
self.assertEqual(d.all_input_nodes, [b, c])
|
|
self.assertEqual(e.all_input_nodes, [d])
|
|
|
|
def test_deepcopy_graphmodule_with_transform(self):
    """Deep-copying a transformed GraphModule gives a new class with the same behavior."""
    traced = symbolic_trace(SimpleTest())
    traced.graph.lint()

    def transform(gm):
        # Copy gm's graph and append a `neg` method call on its output.
        rebuilt = torch.fx.Graph()
        env : Dict[Node, Node] = {}
        copied_out = rebuilt.graph_copy(gm.graph, env)
        neg_out = rebuilt.create_node(
            op='call_method', target='neg', args=(copied_out,), kwargs={})
        rebuilt.output(neg_out)
        return GraphModule(gm, rebuilt)

    transformed = transform(traced)
    transformed.graph.lint()
    duplicate = copy.deepcopy(transformed)
    self.assertNotEqual(id(type(transformed)), id(type(duplicate)))
    inp = torch.randn(3, 4)
    self.assertEqual(duplicate(inp), transformed(inp))
|
|
|
|
def test_deepcopy_with_submods_params(self):
|
|
class Bar(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
|
|
def forward(self, x):
|
|
return torch.relu(x) + self.param
|
|
|
|
class Baz(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
self.bar = Bar()
|
|
|
|
def forward(self, x):
|
|
return self.bar(x) - self.param
|
|
|
|
baz = Baz()
|
|
traced = symbolic_trace(baz)
|
|
traced.graph.lint()
|
|
copied = copy.deepcopy(traced)
|
|
copied.graph.lint()
|
|
|
|
def test_deepcopy_graph_with_tracer_cls(self):
|
|
class TestTracer(Tracer):
|
|
def is_leaf_module(self, module, name):
|
|
return True
|
|
|
|
g = Graph(tracer_cls=TestTracer)
|
|
x = g.placeholder("x")
|
|
g.output(x)
|
|
|
|
h = copy.deepcopy(g)
|
|
self.assertIsNotNone(h._tracer_cls)
|
|
self.assertTrue(g._tracer_cls == h._tracer_cls)
|
|
|
|
def test_unpack_list_better_error(self):
|
|
class SomeArgs(torch.nn.Module):
|
|
def forward(self, a, b):
|
|
return torch.rand(3, 4)
|
|
|
|
class UnpacksList(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sa = SomeArgs()
|
|
|
|
def forward(self, x : list):
|
|
return self.sa(*x)
|
|
|
|
ul = UnpacksList()
|
|
with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
|
|
symbolic_trace(ul)
|
|
|
|
def test_unpack_dict_better_error(self):
|
|
class SomeKwargs(torch.nn.Module):
|
|
def forward(self, x=3, y=4):
|
|
return torch.rand(3, 4)
|
|
|
|
class UnpacksDict(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sk = SomeKwargs()
|
|
|
|
def forward(self, x : dict):
|
|
return self.sk(**x)
|
|
|
|
ud = UnpacksDict()
|
|
with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
|
|
symbolic_trace(ud)
|
|
|
|
def test_pretty_print_targets(self):
|
|
# Test that Graph pretty-print prints friendly name for targets
|
|
# in `operator` and `builtins`
|
|
|
|
class SomeMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.add(x.foo + x.bar, 3.0)
|
|
|
|
traced = symbolic_trace(SomeMod())
|
|
graph_str = str(traced.graph)
|
|
self.assertIn('builtins.getattr', graph_str)
|
|
self.assertIn('operator.add', graph_str)
|
|
self.assertIn('torch.add', graph_str)
|
|
|
|
def test_pretty_print_node(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param: torch.nn.Parameter = torch.nn.Parameter(
|
|
torch.rand(3, 4))
|
|
self.linear = torch.nn.Linear(4, 5)
|
|
|
|
def forward(self, x: torch.Tensor, y: int = 2):
|
|
return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)
|
|
|
|
traced = symbolic_trace(M())
|
|
|
|
all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])
|
|
|
|
FileCheck().check("x").check("placeholder") \
|
|
.check("y").check("placeholder") \
|
|
.check("getitem").check("call_function") \
|
|
.check("param").check("get_attr") \
|
|
.check("add").check("call_function") \
|
|
.check("linear").check("call_module") \
|
|
.check("clamp").check("call_method") \
|
|
.run(all_formatted)
|
|
|
|
def test_script_tensor_constant(self):
|
|
# TorchScript seems to ignore attributes that start with `__`.
|
|
# We used to call anonymous Tensor values `__tensor_constant*`, but
|
|
# they were getting ignored by script. Now they're called
|
|
# `_tensor_constant*`
|
|
class IHaveATensorConstant(torch.nn.Module):
|
|
def forward(self, x):
|
|
return x + torch.rand(3, 4)
|
|
|
|
traced = torch.fx.symbolic_trace(IHaveATensorConstant())
|
|
torch.jit.script(traced)
|
|
|
|
def test_autowrap_functions(self):
    """Functions given via `autowrap_functions` can take Proxy args; the result scripts."""
    class AutowrapFnTest(torch.nn.Module):
        def forward(self, x):
            return fx_int(x.shape[0] / 2)

    class AutowrapFnTest2(torch.nn.Module):
        def forward(self, x):
            return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)

    # Check function(s) are wrapped
    # `int` would normally throw a TypeError as argument can't be `Proxy`
    single_tracer = Tracer(autowrap_functions=(fx_int,))
    single_graph = single_tracer.trace(AutowrapFnTest())
    traced = GraphModule(single_tracer.root, single_graph, 'test')

    Tracer(autowrap_functions=(fx_int, fx_int_x2)).trace(AutowrapFnTest2())

    # Test scriptability
    self.assertEqual(torch.jit.script(traced)(torch.rand(4)), 2)
|
|
|
|
def test_tuple_no_subscript(self):
|
|
def foo(x : Tuple):
|
|
return x[0]
|
|
|
|
traced = torch.fx.symbolic_trace(foo)
|
|
x = (torch.randn(5, 3),)
|
|
torch.testing.assert_close(traced(x), x[0])
|
|
|
|
bio = io.BytesIO()
|
|
|
|
torch.save(traced, bio)
|
|
|
|
bio.seek(0)
|
|
|
|
# weights_only=False as this loads a GraphModule
|
|
# GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
|
|
loaded = torch.load(bio, weights_only=False)
|
|
|
|
torch.testing.assert_close(loaded(x), x[0])
|
|
|
|
def test_torch_fx_len(self):
|
|
class FXLenTest(torch.nn.Module):
|
|
def forward(self, x):
|
|
return len(x)
|
|
|
|
traced = symbolic_trace(FXLenTest())
|
|
self.assertEqual(traced(torch.rand(3, 4)), 3)
|
|
|
|
# Test scriptability
|
|
scripted = torch.jit.script(FXLenTest())
|
|
self.assertEqual(scripted(torch.rand(3)), 3)
|
|
|
|
traced_scripted = torch.jit.script(traced)
|
|
self.assertEqual(traced_scripted(torch.rand(3)), 3)
|
|
|
|
# Test non-proxy len
|
|
class FXLenTest2(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.l = [3, 4, 5]
|
|
|
|
def forward(self, x):
|
|
return x + len(self.l)
|
|
|
|
traced2 = symbolic_trace(FXLenTest2())
|
|
inp = torch.rand(3, 4)
|
|
self.assertEqual(traced2(inp), inp + 3.0)
|
|
self.assertIs(len, builtins.len)
|
|
|
|
def test_torch_fx_getattr(self):
|
|
class FXGetattrTest(torch.nn.Module):
|
|
def forward(self, x):
|
|
return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
|
|
|
|
traced = symbolic_trace(FXGetattrTest())
|
|
self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
|
|
|
|
def test_sqrt(self):
    """Both `sqrt` and `math.sqrt` trace correctly and are restored afterwards."""
    class Sqrt1(torch.nn.Module):
        def forward(self, x):
            return sqrt(x.size(0))

    class Sqrt2(torch.nn.Module):
        def forward(self, x):
            return math.sqrt(x.size(0))

    class Sqrt3(torch.nn.Module):
        def forward(self, x):
            return x + math.sqrt(2) + sqrt(2)

    for module in (Sqrt1(), Sqrt2(), Sqrt3()):
        self.checkGraphModule(module, [torch.zeros(8)])
    self.assertIs(sqrt, _sqrt)
    self.assertIs(math.sqrt, _sqrt)
|
|
|
|
def test_torch_custom_ops(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, a):
|
|
b = torch.ops.aten.sigmoid(a)
|
|
c = torch.ops.aten.cat([a, b])
|
|
return torch.ops.aten.cat((c, c))
|
|
m = M()
|
|
input = torch.randn(3)
|
|
ref_out = m(input)
|
|
gm = symbolic_trace(m)
|
|
gm.graph.lint()
|
|
out = gm(input)
|
|
self.assertEqual(out, ref_out)
|
|
|
|
def test_torch_op_overloads(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, a):
|
|
b = torch.ops.aten.add.Tensor(a, a)
|
|
return b
|
|
m = M()
|
|
input = torch.randn(3)
|
|
ref_out = m(input)
|
|
gm = symbolic_trace(m)
|
|
gm.graph.lint()
|
|
out = gm(input)
|
|
self.assertEqual(out, ref_out)
|
|
|
|
for node in gm.graph.nodes:
|
|
if node.op == 'call_function':
|
|
assert isinstance(node.target, torch._ops.OpOverload)
|
|
assert node.target.__name__ == 'add.Tensor'
|
|
|
|
def test_pickle_torch_custom_ops(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, a):
|
|
b = torch.ops.aten.sigmoid(a)
|
|
c = torch.ops.aten.cat([a, b])
|
|
return torch.ops.aten.cat((c, c))
|
|
m = M()
|
|
input = torch.randn(3)
|
|
ref_out = m(input)
|
|
gm = symbolic_trace(m)
|
|
gm.graph.lint()
|
|
pickled = pickle.dumps(gm)
|
|
loaded = pickle.loads(pickled)
|
|
self.assertEqual(loaded(input), gm(input))
|
|
|
|
def test_pretty_print(self):
    """str(GraphModule) shows the root module name and the traced torch calls."""
    st = SimpleTest()
    traced = symbolic_trace(st)
    traced.graph.lint()
    printed = str(traced)
    # unittest assertions, unlike bare `assert`, are not stripped under `python -O`.
    self.assertIn('SimpleTest()', printed)
    self.assertIn('torch.relu', printed)
|
|
|
|
def test_pretty_print_graph(self):
|
|
class KwargPrintTest(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.squeeze(x + 3.0, dim=2)
|
|
st = KwargPrintTest()
|
|
traced = symbolic_trace(st)
|
|
traced.graph.lint()
|
|
stringed = str(traced.graph)
|
|
for s in ['args', 'kwargs', 'num_users']:
|
|
assert s in stringed
|
|
|
|
def test_custom_proxy_type(self):
|
|
class TensorPair:
|
|
def __init__(self, left, right):
|
|
self.left, self.right = left, right
|
|
|
|
def add(self, other):
|
|
l = self.left + other.left
|
|
r = self.right + other.right
|
|
return TensorPair(l, r)
|
|
|
|
def mul(self, other):
|
|
l = self.left * other.left
|
|
r = self.right * other.right
|
|
return TensorPair(l, r)
|
|
|
|
def use_tensor_pair(x : TensorPair, y : TensorPair):
|
|
s = x.add(y)
|
|
return s.mul(x)
|
|
|
|
x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
|
|
y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
|
|
|
|
ref_out = use_tensor_pair(x, y)
|
|
|
|
traced = symbolic_trace(use_tensor_pair)
|
|
|
|
traced_out = traced(x, y)
|
|
self.assertEqual(traced_out.left, ref_out.left)
|
|
self.assertEqual(traced_out.right, ref_out.right)
|
|
|
|
def test_custom_proxy_type_literal(self):
|
|
class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
|
|
def __init__(self, left, right):
|
|
self.left, self.right = left, right
|
|
|
|
def add(self, other):
|
|
l = self.left + other.left
|
|
r = self.right + other.right
|
|
return TensorPair(l, r)
|
|
|
|
def mul(self, other):
|
|
l = self.left * other.left
|
|
r = self.right * other.right
|
|
return TensorPair(l, r)
|
|
|
|
def use_tensor_pair_literal(x : TensorPair):
|
|
s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
|
|
return s.mul(x)
|
|
|
|
x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
|
|
|
|
ref_out = use_tensor_pair_literal(x)
|
|
|
|
traced = symbolic_trace(use_tensor_pair_literal)
|
|
|
|
traced_out = traced(x)
|
|
self.assertEqual(traced_out.left, ref_out.left)
|
|
self.assertEqual(traced_out.right, ref_out.right)
|
|
|
|
def test_custom_proxy_dynamic_value(self):
|
|
class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
|
|
def __init__(self, left, right):
|
|
self.left, self.right = left, right
|
|
|
|
def add(self, other):
|
|
l = self.left + other.left
|
|
r = self.right + other.right
|
|
return TensorPair(l, r)
|
|
|
|
def mul(self, other):
|
|
l = self.left * other.left
|
|
r = self.right * other.right
|
|
return TensorPair(l, r)
|
|
|
|
def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
|
|
s = x.add(TensorPair(y, y))
|
|
return s.mul(x)
|
|
|
|
x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
|
|
y = torch.randn(5, 3)
|
|
ref_out = use_tensor_pair_ctor(x, y)
|
|
|
|
traced = symbolic_trace(use_tensor_pair_ctor)
|
|
|
|
traced_out = traced(x, y)
|
|
self.assertEqual(traced_out.left, ref_out.left)
|
|
self.assertEqual(traced_out.right, ref_out.right)
|
|
|
|
def test_custom_proxy_input_dependent_control_flow(self):
|
|
class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
|
|
def __init__(self, inp):
|
|
if inp.sum() == 0:
|
|
self.is_zero = True
|
|
self.tensor = torch.tensor([])
|
|
else:
|
|
self.is_zero = False
|
|
self.tensor = inp
|
|
|
|
def add(self, other):
|
|
if self.is_zero:
|
|
return ZeroTensor(other.tensor)
|
|
elif other.is_zero:
|
|
return self
|
|
|
|
def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
|
|
return ZeroTensor(x + y)
|
|
|
|
x, y = torch.randn(5, 3), torch.randn(5, 3)
|
|
|
|
ref_out = use_zero_tensor(x, y)
|
|
|
|
traced = symbolic_trace(use_zero_tensor)
|
|
|
|
traced_out = traced(x, y)
|
|
|
|
self.assertEqual(traced_out.is_zero, ref_out.is_zero)
|
|
self.assertEqual(traced_out.tensor, ref_out.tensor)
|
|
|
|
def test_graph_fns(self):
|
|
g = Graph()
|
|
a = g.placeholder('a')
|
|
b = g.call_module('linear', (a,))
|
|
c = g.get_attr('bias')
|
|
d = g.call_method('add', (b, c))
|
|
e = g.call_function(torch.sin, (d,))
|
|
g.output(e)
|
|
mod = torch.nn.Module()
|
|
mod.linear = torch.nn.Linear(3, 4)
|
|
mod.bias = torch.rand(4)
|
|
gm = GraphModule(mod, g)
|
|
gm.graph.lint()
|
|
input = torch.rand(3)
|
|
r = gm(input)
|
|
ref = torch.sin(mod.linear(input) + mod.bias)
|
|
self.assertEqual(r, ref)
|
|
|
|
def test_remove_uses(self):
|
|
g : torch.fx.Graph = Graph()
|
|
x : torch.fx.Node = g.placeholder('x')
|
|
relu : torch.fx.Node = g.call_function(torch.relu, (x,))
|
|
neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
|
|
g.output(neg)
|
|
|
|
neg.replace_all_uses_with(relu)
|
|
g.erase_node(neg)
|
|
|
|
self.assertTrue(neg not in relu.users)
|
|
|
|
def test_remove_uses_with_custom_filter(self):
|
|
g : torch.fx.Graph = Graph()
|
|
x : torch.fx.Node = g.placeholder('x')
|
|
relu : torch.fx.Node = g.call_function(torch.relu, (x,))
|
|
neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
|
|
g.output(neg)
|
|
|
|
neg.replace_all_uses_with(relu, lambda x: x != neg)
|
|
|
|
self.assertTrue(neg in relu.users)
|
|
|
|
def test_nonetype_annotation(self):
|
|
eb = torch.nn.EmbeddingBag(3, 4)
|
|
symbolic_trace(eb)
|
|
|
|
def test_pickle_nonetype_annotation(self):
|
|
eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
|
|
traced = symbolic_trace(eb)
|
|
pickled = pickle.dumps(traced)
|
|
loaded = pickle.loads(pickled)
|
|
loaded.graph.lint()
|
|
input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
|
|
offsets = torch.LongTensor([0, 4])
|
|
self.assertEqual(loaded(input, offsets), traced(input, offsets))
|
|
|
|
def test_return_tuple(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
return (x, x + x)
|
|
|
|
original = M()
|
|
traced = symbolic_trace(original)
|
|
self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
|
|
|
|
def test_construct_root_dict(self):
|
|
graph : torch.fx.Graph = torch.fx.Graph()
|
|
a : torch.fx.Node = graph.create_node('placeholder', 'x')
|
|
b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
|
|
c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
|
|
d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
|
|
graph.output(d)
|
|
|
|
linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
|
|
add_param : torch.Tensor = torch.rand(3, 4)
|
|
gm : torch.fx.GraphModule = torch.fx.GraphModule(
|
|
{'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
|
|
gm.graph.lint()
|
|
|
|
assert 'self.foo.bar.baz' in gm.code
|
|
|
|
x : torch.Tensor = torch.rand(3, 3)
|
|
out : torch.Tensor = gm(x)
|
|
ref_out : torch.Tensor = linear_mod(x) + add_param
|
|
self.assertEqual(out, ref_out)
|
|
|
|
def test_symbolic_trace_assert(self):
|
|
|
|
class AssertsTensorShape(torch.nn.Module):
|
|
def forward(self, x):
|
|
torch._assert(x.shape[1] > 4, "assert_foobar")
|
|
return x
|
|
|
|
m = AssertsTensorShape()
|
|
# verify traceability
|
|
traced = symbolic_trace(m)
|
|
# verify assertion on traced model works correctly at runtime
|
|
traced(torch.rand(4, 5))
|
|
with self.assertRaisesRegex(AssertionError, "assert_foobar"):
|
|
traced(torch.rand(4, 3))
|
|
# verify the symbolically traced module is scriptable
|
|
ms = torch.jit.script(m)
|
|
with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
|
|
ms(torch.rand(4, 3))
|
|
|
|
def test_fx_create_arg(self):
|
|
class CustomArgObject:
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def __fx_create_arg__(self, tracer: torch.fx.Tracer):
|
|
return tracer.create_node(
|
|
"call_function",
|
|
CustomArgObject,
|
|
args=(
|
|
tracer.create_arg(self.x),
|
|
tracer.create_arg(self.y),
|
|
),
|
|
kwargs={},
|
|
)
|
|
|
|
class HasCustomArgObjectWhenLeaf(torch.nn.Module):
|
|
def forward(self, o: CustomArgObject):
|
|
# Not normally traceable; good reason to make
|
|
# this module a leaf.
|
|
for x in o.x:
|
|
o.y += x
|
|
return o.y
|
|
|
|
class Root(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.inner = HasCustomArgObjectWhenLeaf()
|
|
|
|
def forward(self, x, y):
|
|
o = CustomArgObject(x, y)
|
|
return self.inner(o)
|
|
|
|
class CreateArgTracer(torch.fx.Tracer):
|
|
def is_leaf_module(self, m, module_qualified_name):
|
|
return type(m) is HasCustomArgObjectWhenLeaf
|
|
|
|
m = Root()
|
|
graph = CreateArgTracer().trace(m)
|
|
gm = torch.fx.GraphModule(m, graph)
|
|
assert "CustomArgObject(" in gm.code
|
|
|
|
def test_trace_fn_constant(self):
|
|
some_constant = torch.rand(3, 4)
|
|
|
|
def add_const(x):
|
|
return some_constant + x
|
|
|
|
traced = symbolic_trace(add_const)
|
|
|
|
input = torch.rand(3, 4)
|
|
self.assertEqual(traced(input), add_const(input))
|
|
|
|
def test_copy_no_remap(self):
    """Copying nodes without remapping inputs yields a graph that lint rejects."""
    source_graph = symbolic_trace(SimpleTest()).graph
    copied = torch.fx.Graph()
    for node in source_graph.nodes:
        # node_copy is called without an argument-remapping function.
        copied.node_copy(node)
    with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
        copied.lint()
|
|
|
|
def test_wrong_topo(self):
|
|
graph : torch.fx.Graph = torch.fx.Graph()
|
|
a : torch.fx.Node = graph.create_node('placeholder', 'x')
|
|
b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
|
|
c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
|
|
d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
|
|
graph.output(d)
|
|
nodes = list(graph.nodes)
|
|
nodes[3].append(nodes[2])
|
|
with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
|
|
graph.lint()
|
|
|
|
def test_wrong_target_type(self):
|
|
graph : torch.fx.Graph = torch.fx.Graph()
|
|
with self.assertRaises(ValueError):
|
|
n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
|
|
args=(), kwargs={})
|
|
|
|
def test_example_shape_prop(self):
|
|
class TestCase(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.attr = torch.randn(3, 4)
|
|
self.submod = torch.nn.Linear(4, 4)
|
|
|
|
def forward(self, x):
|
|
return torch.neg(self.submod(x.relu() + self.attr))
|
|
tc = TestCase()
|
|
tc_traced = symbolic_trace(tc)
|
|
ref_out = tc_traced(torch.rand(3, 4))
|
|
shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
|
|
|
|
# Make sure we're testing all opcodes
|
|
opcodes = set()
|
|
output_shape : Optional[torch.Shape] = None
|
|
output_stride : Optional[Tuple[int]] = None
|
|
for node in tc_traced.graph.nodes:
|
|
opcodes.add(node.op)
|
|
if node.op == 'output':
|
|
output_shape = node.args[0].meta['tensor_meta'].shape
|
|
output_stride = node.args[0].meta['tensor_meta'].stride
|
|
self.assertEqual(opcodes, {'placeholder', 'get_attr', 'call_function', 'call_method',
|
|
'call_module', 'output'})
|
|
|
|
# Test shape propagation and make sure results match actual
|
|
self.assertEqual(output_shape, ref_out.shape)
|
|
self.assertEqual(output_stride, ref_out.stride())
|
|
|
|
def test_shape_prop_layout(self):
|
|
class ConvTest(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv_mod = torch.nn.Conv2d(5, 5, 3)
|
|
|
|
def forward(self, x):
|
|
return self.conv_mod(x)
|
|
|
|
# contiguous layout
|
|
test_mod = ConvTest()
|
|
traced = symbolic_trace(test_mod)
|
|
x = torch.randn(5, 5, 224, 224)
|
|
shape_prop.ShapeProp(traced).propagate(x)
|
|
|
|
assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
|
|
for node in traced.graph.nodes)
|
|
|
|
x_channels_last = x.contiguous(memory_format=torch.channels_last)
|
|
traced.to(memory_format=torch.channels_last)
|
|
shape_prop.ShapeProp(traced).propagate(x_channels_last)
|
|
for node in traced.graph.nodes:
|
|
# NB: the implementation of conv may not preserve the memory format,
|
|
# unfortunately. The best we can do is just check that the placeholder
|
|
# node is channels-last
|
|
if node.op in {'placeholder'}:
|
|
self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
|
|
|
|
def test_shape_prop_aggregate(self):
|
|
class ReturnTwo(torch.nn.Module):
|
|
def forward(self, x):
|
|
return (3, torch.sum(x))
|
|
|
|
class UnderTest(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.rt = ReturnTwo()
|
|
|
|
def forward(self, x):
|
|
return self.rt(x)
|
|
|
|
ut = UnderTest()
|
|
|
|
class RTTracer(torch.fx.Tracer):
|
|
def is_leaf_module(self, m, module_qualified_name):
|
|
return type(m) is ReturnTwo
|
|
|
|
graph = RTTracer().trace(ut)
|
|
mod = torch.fx.GraphModule(ut, graph)
|
|
|
|
shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
|
|
|
|
for node in mod.graph.nodes:
|
|
if node.op == 'call_module':
|
|
assert 'tensor_meta' in node.meta
|
|
tensor_meta = node.meta['tensor_meta']
|
|
assert tensor_meta[0] == 3
|
|
assert tensor_meta[1].shape == torch.Size([])
|
|
|
|
def test_shape_prop_layout_3d(self):
|
|
class ConvTest3d(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv_mod = torch.nn.Conv3d(5, 5, 3)
|
|
|
|
def forward(self, x):
|
|
return self.conv_mod(x)
|
|
|
|
test_mod_3d = ConvTest3d()
|
|
traced_3d = symbolic_trace(test_mod_3d)
|
|
x_3d = torch.randn(5, 5, 224, 224, 15)
|
|
shape_prop.ShapeProp(traced_3d).propagate(x_3d)
|
|
assert all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
|
|
for node in traced_3d.graph.nodes)
|
|
|
|
x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
|
|
traced_3d.to(memory_format=torch.channels_last_3d)
|
|
shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
|
|
for node in traced_3d.graph.nodes:
|
|
# NB: the implementation of conv may not preserve the memory format,
|
|
# unfortunately. The best we can do is just check that the placeholder
|
|
# node is channels-last
|
|
if node.op in {'placeholder'}:
|
|
self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
|
|
|
|
def test_nn_module_stack(self):
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv_mod = torch.nn.Conv2d(64, 64, (3, 3), padding=1, bias=False)
|
|
|
|
def forward(self, x):
|
|
return self.conv_mod(x)
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sub_mod = SubModule()
|
|
|
|
def forward(self, x):
|
|
return self.sub_mod(x)
|
|
|
|
m = MyModule()
|
|
gm = torch.fx.symbolic_trace(m)
|
|
|
|
mod_stack = {}
|
|
expected_stack = [('sub_mod', ('sub_mod', type(m.sub_mod))),
|
|
('sub_mod.conv_mod', ('sub_mod.conv_mod', type(m.sub_mod.conv_mod)))]
|
|
for node in gm.graph.nodes:
|
|
mod_stack = node.meta.get('nn_module_stack', {})
|
|
if mod_stack:
|
|
break
|
|
stack_list = list(mod_stack.items())
|
|
self.assertEqual(stack_list, expected_stack)
|
|
|
|
def test_transformer_preserves_nn_module_stack_for_get_attr(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight = torch.nn.Parameter(torch.ones(1, 1))
|
|
|
|
def forward(self, x):
|
|
return self.weight + x
|
|
|
|
tracer = torch.fx.Tracer()
|
|
graph = tracer.trace(M())
|
|
gm = GraphModule(tracer.root, graph)
|
|
for node in gm.graph.nodes:
|
|
if node.op == 'get_attr':
|
|
node.meta["nn_module_stack"] = "self"
|
|
node.meta["stack_trace"] = "stack_trace"
|
|
node.meta["source_fn_stack"] = "source_fn_stack"
|
|
new_gm = Transformer(gm).transform()
|
|
for node in new_gm.graph.nodes:
|
|
if node.op == 'get_attr':
|
|
self.assertEqual(node.meta["nn_module_stack"], "self")
|
|
self.assertEqual(node.meta["stack_trace"], "stack_trace")
|
|
self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
|
|
|
|
def test_interpreter(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
self.linear = torch.nn.Linear(4, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
|
|
|
m = MyModule()
|
|
gm = torch.fx.symbolic_trace(m)
|
|
|
|
interpreter = Interpreter(gm)
|
|
input = torch.randn(3, 4)
|
|
self.assertEqual(interpreter.run(input), gm(input))
|
|
self.assertEqual(interpreter.run(input), m(input))
|
|
|
|
def test_interpreter_other_graph(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
self.linear = torch.nn.Linear(4, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
|
|
|
m = MyModule()
|
|
gm = torch.fx.symbolic_trace(m)
|
|
|
|
interpreter = Interpreter(gm, graph=gm.graph)
|
|
input = torch.randn(3, 4)
|
|
self.assertEqual(interpreter.run(input), gm(input))
|
|
self.assertEqual(interpreter.run(input), m(input))
|
|
|
|
def test_interpreter_run_node_override(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
self.linear = torch.nn.Linear(4, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
|
|
|
m = MyModule()
|
|
gm = torch.fx.symbolic_trace(m)
|
|
|
|
class RunNodeInterpreter(Interpreter):
|
|
def __init__(self, module):
|
|
super().__init__(module)
|
|
|
|
def run_node(self, n : Node) -> Any:
|
|
result = super().run_node(n)
|
|
n.cached_value = result
|
|
return result
|
|
|
|
input = torch.randn(3, 4)
|
|
RunNodeInterpreter(gm).run(input)
|
|
for node in gm.graph.nodes:
|
|
assert hasattr(node, 'cached_value')
|
|
|
|
def test_interpreter_onthefly_swap(self):
|
|
|
|
def fn(x):
|
|
return torch.sigmoid(x).neg()
|
|
|
|
gm = torch.fx.symbolic_trace(fn)
|
|
|
|
class NegSigmSwapInterpreter(Interpreter):
|
|
def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
|
|
if target == torch.sigmoid:
|
|
return torch.neg(*args, **kwargs)
|
|
return super().call_function(n) # noqa: F821
|
|
|
|
def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
|
|
if target == 'neg':
|
|
call_self, *args_tail = args
|
|
return call_self.sigmoid(*args_tail, **kwargs)
|
|
return super().call_method(n) # noqa: F821
|
|
|
|
input = torch.randn(3, 4)
|
|
result = NegSigmSwapInterpreter(gm).run(input)
|
|
self.assertEqual(result, torch.neg(input).sigmoid())
|
|
|
|
def test_interpreter_partial_eval(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4))
|
|
self.linear = torch.nn.Linear(4, 5)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x + self.param).clamp(min=0.0, max=1.0)
|
|
|
|
gm = torch.fx.symbolic_trace(MyModule())
|
|
interp = Interpreter(gm)
|
|
env = {}
|
|
for node in gm.graph.nodes:
|
|
if node.op == 'call_module' and node.target == 'linear':
|
|
env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
|
|
break
|
|
assert len(env) == 1
|
|
x = torch.randn(3, 4)
|
|
result = interp.run(x, initial_env=env)
|
|
self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
|
|
|
|
def test_interpreter_star_args(self):
|
|
def with_star_args(x, *args):
|
|
return x + args[0]
|
|
|
|
gm = torch.fx.symbolic_trace(with_star_args)
|
|
interp = Interpreter(gm)
|
|
result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
|
|
self.assertEqual(result, torch.ones(3, 4) * 2.0)
|
|
|
|
@skipIfNoTorchVision
def test_interpreter_noop_resnet18(self):
    """A no-op Transformer pass over traced ResNet-18 preserves its outputs."""
    model = torchvision_models.resnet18()
    rebuilt = torch.fx.Transformer(symbolic_trace(model)).transform()
    batch = torch.randn(5, 3, 224, 224)
    self.assertEqual(rebuilt(batch), model(batch))
|
|
|
|
@skipIfNoTorchVision
def test_interpreter_gc_values(self):
    """After a run, the Interpreter env holds only the output node's value."""
    rn18 = torchvision_models.resnet18()
    interp = Interpreter(symbolic_trace(rn18))
    inp = torch.rand(5, 3, 224, 224)
    # Only the effect on interp.env is checked; the return value is not needed.
    interp.run(inp)
    env_key_names = {n.name for n in interp.env}
    self.assertEqual(env_key_names, {'output'})
|
|
|
|
def test_interpreter_default_args(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y=3.14159):
|
|
return x + y
|
|
|
|
model = Model()
|
|
gm = torch.fx.symbolic_trace(model)
|
|
|
|
interp = Interpreter(gm)
|
|
x = torch.randn(5, 3)
|
|
out = interp.run(x)
|
|
torch.testing.assert_close(out, x + 3.14159)
|
|
|
|
def test_interpreter_not_enough_args(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
return x + y
|
|
|
|
model = Model()
|
|
gm = torch.fx.symbolic_trace(model)
|
|
|
|
interp = Interpreter(gm)
|
|
x = torch.randn(5, 3)
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
'Expected positional argument for parameter y, but one was not passed in'):
|
|
out = interp.run(x)
|
|
|
|
def test_transformer_noop(self):
    """An identity Transformer yields a module numerically equal to its source."""
    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            return self.linear(x + self.param).clamp(min=0.0, max=1.0)

    traced = torch.fx.symbolic_trace(MyModule())
    rebuilt = Transformer(traced).transform()

    sample = torch.randn(3, 4)
    self.assertEqual(rebuilt(sample), traced(sample))
|
|
|
|
def test_transformer_op_swap(self):
    """A Transformer can swap ``torch.sigmoid`` and ``Tensor.neg`` in a graph.

    Any call the transformer does not explicitly rewrite is delegated to the
    base ``Transformer`` implementation with its original target/args/kwargs.
    """

    def fn(x):
        return torch.sigmoid(x).neg()

    gm = torch.fx.symbolic_trace(fn)

    class NegSigmSwapXformer(Transformer):
        def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == torch.sigmoid:
                return torch.neg(*args, **kwargs)
            # Fallback previously called ``super().call_function(n)`` with an
            # undefined name ``n``; forward the real call signature instead.
            return super().call_function(target, args, kwargs)

        def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
            if target == 'neg':
                call_self, *args_tail = args
                return call_self.sigmoid(*args_tail, **kwargs)
            return super().call_method(target, args, kwargs)

    transformed = NegSigmSwapXformer(gm).transform()
    inp = torch.randn(3, 4)
    self.assertEqual(transformed(inp), torch.neg(inp).sigmoid())
|
|
|
|
def test_transformer_multi_outputs(self):
    """An identity Transformer preserves a forward that returns a tuple."""
    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.param = torch.nn.Parameter(torch.rand(3, 4))
            self.linear = torch.nn.Linear(4, 5)

        def forward(self, x):
            shifted = x + self.param
            return shifted, self.linear(shifted)

    traced = torch.fx.symbolic_trace(MyModule())
    rebuilt = Transformer(traced).transform()

    sample = torch.randn(3, 4)
    self.assertEqual(rebuilt(sample), traced(sample))
|
|
|
|
def test_fn_type_annotations(self):
    """Forward annotations (incl. a NamedTuple arg) survive tracing well enough to script."""
    class Foo(torch.nn.Module):
        def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
            return {'a': p.x + p.y + z + i}

    def make_inputs():
        return Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3

    # The eager module is scriptable ...
    torch.jit.script(Foo())(*make_inputs())

    # ... and so is its FX-traced counterpart.
    torch.jit.script(symbolic_trace(Foo()))(*make_inputs())
|
|
|
|
def test_fn_type_annotation_empty(self):
    """A traced free function annotated with ``List[torch.Tensor]`` can be scripted."""
    def forward(a : List[torch.Tensor]):
        return a[0]

    traced = symbolic_trace(forward)
    torch.jit.script(traced)
|
|
|
|
def test_wrapped_method(self):
    """Tracing sees through a ``functools.wraps`` decorator applied to ``forward``.

    The decorator's ``torch.relu`` must appear in the traced graph, and the
    traced module must compute the same result as the eager module.
    """
    def wrap_with_relu(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return torch.relu(fn(*args, **kwargs))
        return wrapper

    class Foo(torch.nn.Module):
        @wrap_with_relu
        def forward(self, x, w):
            return torch.matmul(x, w)

    f = Foo()
    traced = symbolic_trace(f)
    x, w = torch.rand(3, 4), torch.rand(4, 4)
    self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
    # x and w were previously created but never used: run both paths.
    self.assertEqual(traced(x, w), f(x, w))
|
|
|
|
def test_empty_graph_codegen(self):
    """A GraphModule built from a graph with no nodes returns None when called."""
    empty_graph = torch.fx.Graph()
    empty_module = torch.fx.GraphModule(torch.nn.Module(), empty_graph)
    self.assertEqual(empty_module(), None)
|
|
|
|
def test_sequential(self):
    """A traced ``nn.Sequential`` can be deep-copied, and the copy still runs."""
    m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
    gm = torch.fx.symbolic_trace(m)
    gm_copy = copy.deepcopy(gm)
    # Previously the copy was created but never exercised; make sure it is a
    # working module that agrees with the original.
    x = torch.rand(1, 1, 4, 4)
    self.assertEqual(gm_copy(x), gm(x))
|
|
|
|
def test_ctx_mgr(self):
    """A forward decorated with a ``@contextmanager`` instance can still be traced."""
    @contextlib.contextmanager
    def do_nothing():
        yield

    class M(torch.nn.Module):
        # contextmanager objects double as decorators (contextlib.ContextDecorator).
        @do_nothing()
        def forward(self, x):
            return torch.relu(x)

    self.checkGraphModule(M(), (torch.rand(3, 4),))
|
|
|
|
def test_typename_print(self):
    """A node's ``type_expr`` is printed with its fully qualified typing name."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
                                          type_expr=List[float])
    graph.output(b)

    # assertIn shows the whole graph text on failure, unlike assertTrue(... in ...).
    self.assertIn('typing.List[float]', str(graph))
|
|
|
|
def test_layout(self):
    """``layout`` and ``pin_memory`` kwargs on ``empty_like`` survive tracing."""
    class M(torch.nn.Module):
        def forward(self, x):
            return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)

    traced = symbolic_trace(M())
    sample = torch.rand(5, 9, 3, 4)
    self.assertEqual(traced(sample), torch.zeros_like(sample))
|
|
|
|
def test_ellipsis(self):
    """Slices containing ``...`` are recorded and replayed correctly."""
    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y[:, 1:10, ...]

    traced = symbolic_trace(M())
    lhs = torch.rand(5, 9, 3, 4)
    rhs = torch.rand(5, 15, 3, 4)
    self.assertEqual(traced(lhs, rhs), lhs + rhs[:, 1:10, ...])
|
|
|
|
def test_inf_nan(self):
    """Non-finite float constants (inf, -inf, nan) round-trip through FX codegen."""
    class FooMod(torch.nn.Module):
        def forward(self, x):
            return x + float('inf'), x + float('-inf'), x + float('nan')

    self.checkGraphModule(FooMod(), (torch.rand(3, 4),))
|
|
|
|
def test_inf_nan_kwds(self):
    """Codegen works when nodes are named ``inf``/``nan`` and also use those float constants."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    add_inf : torch.fx.Node = graph.create_node(
        'call_function', operator.add, (x, float('inf')), {}, name='inf')
    add_nan : torch.fx.Node = graph.create_node(
        'call_function', operator.add, (x, float('nan')), {}, name='nan')
    graph.output((add_inf, add_nan))

    gm = torch.fx.GraphModule(torch.nn.Module(), graph)
    sample = torch.rand(3, 4)
    self.assertEqual(gm(sample), (sample + float('inf'), sample + float('nan')))
|
|
|
|
def test_deepcopy_recursion_depth(self):
    """Deep-copying a node chain longer than the recursion limit preserves use-def edges."""
    chain_len = sys.getrecursionlimit() + 20

    g = torch.fx.Graph()
    cur = g.placeholder('x')
    for _ in range(chain_len):
        cur = g.call_function(torch.relu, (cur,))
    g.output(cur)

    copied_graph = copy.deepcopy(g)

    # Pair each original node with the node at the same position in the copy.
    val_map = dict(zip(g.nodes, copied_graph.nodes))

    for orig_node, new_node in val_map.items():
        expected_users = {val_map[u] for u in orig_node.users}
        self.assertEqual(expected_users, set(new_node.users))
|
|
|
|
@skipIfNoTorchVision
def test_replace_uses(self):
    """Every relu in a traced ResNet-18 can be replaced by neg, leaving a valid graph."""
    rn18 = torchvision_models.resnet18()

    class LowerReluTracer(torch.fx.Tracer):
        # Trace *through* nn.ReLU so it appears as a functional relu call.
        def is_leaf_module(self, m : torch.nn.Module, qualname : str):
            if isinstance(m, torch.nn.ReLU):
                return False
            return super().is_leaf_module(m, qualname)

    rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))

    to_erase = []
    for node in rn18_traced.graph.nodes:
        if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
            kwargs = dict(node.kwargs)
            # Neg doesn't have in-place; drop the flag (if present).
            kwargs.pop('inplace', None)
            with rn18_traced.graph.inserting_before(node):
                # Previously ``node.kwargs`` (still holding ``inplace``) was
                # passed here, silently discarding the cleaned-up copy.
                new_node = rn18_traced.graph.call_function(
                    the_function=torch.neg, args=node.args, kwargs=kwargs)
            node.replace_all_uses_with(replace_with=new_node)
            to_erase.append(node)

    self.assertTrue(to_erase, "expected at least one relu node to replace")
    for node in to_erase:
        rn18_traced.graph.erase_node(node)

    # Validate the rewritten graph and make sure it still generates code.
    rn18_traced.graph.lint()
    rn18_traced.recompile()
|
|
|
|
def test_replace_input(self):
    """Node.replace_input_with rewires one argument of a node."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    y : torch.fx.Node = graph.create_node('placeholder', 'y')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
    graph.output(relu_node)

    relu_node.replace_input_with(x, y)
    gm = torch.fx.GraphModule(torch.nn.Module(), graph)

    # relu now consumes y, so only the second input should matter.
    inp_x, inp_y = torch.randn(33, 44), torch.randn(11, 22)
    self.assertEqual(gm(inp_x, inp_y), torch.relu(inp_y))
|
|
|
|
def test_insertion_point(self):
    """Nodes created under ``inserting_before`` are placed ahead of the anchor node."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
    graph.output(relu_node)

    with graph.inserting_before(relu_node):
        neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
        # Route relu's first input through the freshly inserted neg.
        first, *rest = relu_node.args
        relu_node.args = (neg, *rest)

    gm = torch.fx.GraphModule(torch.nn.Module(), graph)
    sample = torch.randn(33, 44)
    self.assertEqual(gm(sample), torch.relu(torch.neg(sample)))
|
|
|
|
def test_update_args_api(self):
    """Node.update_arg changes a positional argument; regenerated code reflects it."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    y : torch.fx.Node = graph.create_node('placeholder', 'y')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
    graph.output(relu_node)

    before = torch.fx.GraphModule(torch.nn.Module(), graph)
    inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
    self.assertEqual(before(inp_x, inp_y), torch.relu(inp_x))

    relu_node.update_arg(0, y)
    after = torch.fx.GraphModule(torch.nn.Module(), graph)
    self.assertEqual(after(inp_x, inp_y), torch.relu(inp_y))
|
|
|
|
def test_update_kwargs_api(self):
    """Node.update_kwarg changes a keyword argument; regenerated code reflects it."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    y : torch.fx.Node = graph.create_node('placeholder', 'y')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
    graph.output(relu_node)

    before = torch.fx.GraphModule(torch.nn.Module(), graph)
    inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
    self.assertEqual(before(inp_x, inp_y), torch.relu(inp_x))

    relu_node.update_kwarg('input', y)
    after = torch.fx.GraphModule(torch.nn.Module(), graph)
    self.assertEqual(after(inp_x, inp_y), torch.relu(inp_y))
|
|
|
|
def test_immutable_list_pytree_ops(self):
    """immutable_list round-trips through pytree flatten/unflatten and keeps its type."""
    tensor = torch.randn(5, 3)
    original = immutable_list([3, [tensor, 42]])

    leaves, treespec = pytree.tree_flatten(original)
    assert leaves == [3, tensor, 42]

    rebuilt = pytree.tree_unflatten(leaves, treespec)
    assert rebuilt == original
    assert isinstance(rebuilt, immutable_list)
|
|
|
|
def test_immutable_dict_pytree_ops(self):
    """immutable_dict round-trips through pytree flatten/unflatten and keeps its type."""
    tensor = torch.randn(5, 3)
    original = immutable_dict({'a': 3, 'b': [tensor, 42]})

    leaves, treespec = pytree.tree_flatten(original)
    assert leaves == [3, tensor, 42]

    rebuilt = pytree.tree_unflatten(leaves, treespec)
    assert rebuilt == original
    assert isinstance(rebuilt, immutable_dict)
|
|
|
|
def test_move_before(self):
    """Node.prepend moves an already-created node to just before another node."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
    graph.output(relu_node)

    # neg is created at the graph's default insertion point ...
    neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
    first, *rest = relu_node.args
    relu_node.args = (neg, *rest)
    # ... and then moved so that it precedes its user.
    relu_node.prepend(neg)

    gm = torch.fx.GraphModule(torch.nn.Module(), graph)
    sample = torch.randn(33, 44)
    self.assertEqual(gm(sample), torch.relu(torch.neg(sample)))
|
|
|
|
def test_prepend_self(self):
    """Prepending a node to itself, or appending it where it already is, keeps the graph intact."""
    graph : torch.fx.Graph = torch.fx.Graph()
    x : torch.fx.Node = graph.create_node('placeholder', 'x')
    relu_node : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
    graph.output(relu_node)

    relu_node.prepend(relu_node)  # node placed before itself
    x.append(relu_node)           # relu_node already directly follows x
    self.assertEqual(len(graph.nodes), 3)
|
|
|
|
def test_erase_node_error(self):
    """Erasing a node that still has users raises a RuntimeError."""
    traced = symbolic_trace(SimpleTest())
    targets_with_users = [operator.add, torch.relu]

    for node in traced.graph.nodes:
        # Covers uses both by another Node and by the output node.
        if node.target not in targets_with_users:
            continue
        with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
            traced.graph.erase_node(node)
|
|
|
|
def test_copy_it(self):
    """immutable_dict and immutable_list compare equal to their deep copies."""
    pairs = [(3, 4), (5, 6)]
    frozen_dict = immutable_dict(pairs)
    frozen_list = immutable_list(pairs)

    for container in (frozen_dict, frozen_list):
        self.assertEqual(container, deepcopy(container))
|
|
|
|
def test_get_torch_func_signature(self):
    """get_signature_for_torch_op does not raise for any callable in ``torch``."""
    for attr_name in dir(torch):
        candidate = getattr(torch, attr_name)
        if not callable(candidate):
            continue
        get_signature_for_torch_op(candidate)
|
|
|
|
def test_find_uses(self):
    """Node.users lists each distinct consuming node once."""
    graph = torch.fx.Graph()
    x = torch.fx.Proxy(graph.placeholder('x'))

    relu_out = torch.relu(x)
    sum_out = x + x  # two uses, but a single user node
    neg_out = torch.neg(x)
    graph.output((relu_out + sum_out + neg_out).node)
    graph.lint()

    consumers = x.node.users
    self.assertEqual(len(consumers), 3)
    prefixes = {'relu', 'add', 'neg'}
    for consumer in consumers:
        assert any(consumer.name.startswith(p) for p in prefixes)
|
|
|
|
def test_inline_graph(self):
    """Graph.graph_copy can splice one traced graph onto another graph's output."""
    class InlineInto(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class ToInline(torch.nn.Module):
        def forward(self, x):
            return torch.neg(x)

    inline_into = symbolic_trace(InlineInto())
    to_inline = symbolic_trace(ToInline())

    combined_graph = torch.fx.Graph()
    # graph_copy returns the copied graph's output value inside combined_graph.
    relu_value = combined_graph.graph_copy(inline_into.graph, {})

    placeholder = next(iter(to_inline.graph.nodes))
    assert placeholder and placeholder.op == 'placeholder'

    # Feed the relu result into ToInline's placeholder instead of a new input.
    neg_value = combined_graph.graph_copy(to_inline.graph, {placeholder : relu_value})
    combined_graph.output(neg_value)

    combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
    sample = torch.rand(3, 4)
    self.assertEqual(combined_module(sample), sample.relu().neg())
|
|
|
|
def test_multi_insert_point(self):
    """Several nodes created inside one ``inserting_before`` block keep their relative order."""
    graph = torch.fx.Graph()
    x = torch.fx.Proxy(graph.placeholder('x'))
    relu = torch.relu(x)

    with graph.inserting_before(relu.node):
        neg = torch.neg(x)
        tanh = torch.tanh(neg)

    graph.output((relu.node, tanh.node))
    graph.lint()

    expected_order = ['x', 'neg', 'tanh', 'relu']
    for node, fragment in zip(graph.nodes, expected_order):
        assert fragment in node.name
|
|
|
|
def test_reassign_args_kwargs_uses(self):
    """Reassigning ``Node.args`` keeps every affected node's ``users`` in sync."""
    graph = torch.fx.Graph()
    x = Proxy(graph.placeholder('x'))
    y = Proxy(graph.placeholder('y'))
    z = x + y
    zed = z + z + z
    graph.output(zed.node)
    graph.lint()

    # zed = z + z + z  ->  zed = z + z + x : x gains zed as a user.
    zed.node.args = (zed.node.args[0], x.node)
    self.assertEqual(list(x.node.users), [z.node, zed.node])

    # z = x + y  ->  z = y + y : x loses z as a user.
    z.node.args = (y.node, y.node)
    self.assertEqual(list(x.node.users), [zed.node])
|
|
|
|
def test_trace_function(self):
    """A plain Python function (not an nn.Module) can be traced and checked."""
    def foo(x, y):
        return torch.relu(x) + y

    inputs = (torch.randn(3, 4), torch.randn(3, 4))
    self.checkGraphModule(foo, inputs)
|
|
|
|
def test_trace_return_dataclass(self):
    """
    A traced module may return a dataclass instance built from traced values.
    """
    from dataclasses import dataclass

    @dataclass
    class MyOutput:
        foo: torch.Tensor
        bar: torch.Tensor

    class ModuleReturnDataclass(torch.nn.Module):
        def forward(self, d : torch.Tensor):
            return MyOutput(foo=d + d, bar=d * 3)

    module = ModuleReturnDataclass()
    traced_graph = symbolic_trace(module).graph
    print(traced_graph)

    rebuilt = GraphModule(module, traced_graph)
    sample = torch.rand(1)
    self.assertEqual(module(sample), rebuilt(sample))
|
|
|
|
def test_trace_return_dataclass_nested(self):
    """
    A traced module may re-wrap a submodule's dataclass output in a new dataclass.
    """
    from dataclasses import dataclass

    @dataclass
    class MyOutput:
        foo: torch.Tensor
        bar: torch.Tensor

    class ModuleReturnDataclass(torch.nn.Module):
        def forward(self, d : torch.Tensor):
            return MyOutput(foo=d + d, bar=d * 3)

    class CallsModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = ModuleReturnDataclass()

        def forward(self, x):
            inner = self.m(x)
            return MyOutput(foo=inner.foo, bar=inner.bar)

    module = CallsModule()
    traced_graph = symbolic_trace(module).graph
    print(traced_graph)

    rebuilt = GraphModule(module, traced_graph)
    sample = torch.rand(1)
    self.assertEqual(module(sample), rebuilt(sample))
|
|
|
|
def test_trace_return_namedtuple(self):
    """
    A traced module may return a ``typing.NamedTuple`` instance.
    """
    class MyOutput(NamedTuple):
        foo: torch.Tensor
        bar: torch.Tensor

    class ModuleReturnNamedTuple(torch.nn.Module):
        def forward(self, d : torch.Tensor):
            return MyOutput(foo=d, bar=d)

    module = ModuleReturnNamedTuple()
    traced_graph = symbolic_trace(module).graph
    print(traced_graph)

    rebuilt = GraphModule(module, traced_graph)
    sample = torch.rand(1)
    self.assertEqual(module(sample), rebuilt(sample))
|
|
|
|
def test_trace_dict_int_keys(self):
    """A dict with int keys can be passed as an argument to a leaf module."""
    class ModWithDictArg(torch.nn.Module):
        def forward(self, d : Dict[int, torch.Tensor]):
            return d[42]

    class CallsModWithDict(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = ModWithDictArg()

        def forward(self, x):
            return self.m({42: x})

    class MyTracer(torch.fx.Tracer):
        # Keep ModWithDictArg opaque so the dict literal becomes a node argument.
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
            return isinstance(m, ModWithDictArg)

    # Tracing only needs to succeed.
    MyTracer().trace(CallsModWithDict())
|
|
|
|
def test_trace_dict_proxy_keys(self):
    """A dict keyed by a traced value cannot become a node argument."""
    class ModWithDictArg(torch.nn.Module):
        def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
            return d[42]

    class CallsModWithDict(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.m = ModWithDictArg()

        def forward(self, x):
            return self.m({x: x})

    class MyTracer(torch.fx.Tracer):
        # Keep ModWithDictArg opaque so the dict literal becomes a node argument.
        def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
            return isinstance(m, ModWithDictArg)

    with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
        MyTracer().trace(CallsModWithDict())
|
|
|
|
def test_module_deepcopy_edit_nodes(self):
    """Editing a deep-copied GraphModule's graph leaves the original untouched."""
    class Foo(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    traced1 = symbolic_trace(Foo())
    copied = copy.deepcopy(traced1)

    relu_nodes = [n for n in copied.graph.nodes if n.target == torch.relu]
    for node in relu_nodes:
        node.target = torch.neg

    for gm in (copied, traced1):
        gm.recompile()

    sample = torch.randn(15, 15)
    torch.testing.assert_close(traced1(sample), torch.relu(sample))
    torch.testing.assert_close(copied(sample), torch.neg(sample))
|
|
|
|
def test_direct_param_use(self):
    """Using a submodule's parameter directly yields no ``constant`` targets in the graph."""
    class TransposeTest(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.b = torch.nn.Parameter(torch.rand(4, 3))

        def forward(self, x):
            return self.b

    class Foo(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.a = TransposeTest()

        def forward(self, x):
            return self.a.b, self.a.b.t(), self.a.b.view(12)

    traced = torch.fx.symbolic_trace(Foo())
    targets = [node.target for node in traced.graph.nodes]
    assert all('constant' not in target for target in targets)
|
|
|
|
def test_single_default_arg(self):
    """A forward whose only parameter has a default works with and without it."""
    class M(torch.nn.Module):
        def forward(self, y=1):
            return y

    m = M()
    for call_args in ((), (3,)):
        self.checkGraphModule(m, call_args)
|
|
|
|
def test_multiple_default_args(self):
    """A forward with several defaulted parameters works for every arity."""
    class M(torch.nn.Module):
        def forward(self, y=1, z=2):
            return y + z

    m = M()
    for call_args in ((), (3,), (3, 4)):
        self.checkGraphModule(m, call_args)
|
|
|
|
def test_regular_and_default_args(self):
    """A forward mixing a required and a defaulted parameter works for both call shapes."""
    class M(torch.nn.Module):
        def forward(self, x, y=1):
            return x + y

    m = M()
    for call_args in ((2,), (2, 3)):
        self.checkGraphModule(m, call_args)
|
|
|
|
def test_string_literal_return(self):
    """A forward returning a constant string round-trips through FX."""
    class M(torch.nn.Module):
        def forward(self):
            return "foo"

    self.checkGraphModule(M(), ())
|
|
|
|
def test_namedtuple_return_qualname(self):
    """Returning ``MyNamedTup`` (defined elsewhere in this file) works after tracing."""
    class NamedTupReturn(torch.nn.Module):
        def forward(self, x):
            return MyNamedTup(x, x)

    traced = symbolic_trace(NamedTupReturn())
    sample = torch.rand(3, 4)
    self.assertEqual(traced(sample), MyNamedTup(sample, sample))
|
|
|
|
def test_update_args_kwargs_yells_at_you(self):
    """Node's double-underscore ``__update_args_kwargs`` cannot be called from outside Node."""
    symtraced = symbolic_trace(SimpleTest())
    first_node = next(iter(symtraced.graph.nodes))
    with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
        first_node.__update_args_kwargs((), {})
|
|
|
|
def test_torchbind_class_attribute_in_fx(self):
    """A TorchBind object stored on a module can be called from a traced forward."""
    if IS_FBCODE or IS_WINDOWS or IS_MACOS:
        self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")

    class FooBar1234(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])

        def forward(self):
            return self.f.top()

    self.checkGraphModule(FooBar1234(), ())
|
|
|
|
def test_torchbind_class_attribute_in_fx_tensor_arg(self):
    """A TorchBind method taking a tensor traces to a call_method node and runs correctly."""
    if IS_FBCODE or IS_WINDOWS or IS_MACOS:
        self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")

    class FooBar2341(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.f = torch.classes._TorchScriptTesting._ReLUClass()

        def forward(self, x):
            return self.f.run(x)

    m = FooBar2341()
    traced = symbolic_trace(m)
    sample = torch.randn(3, 4)
    self.assertEqual(traced(sample), m(sample))

    self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
|
|
|
|
def test_script_method_trace(self):
    """Tracing through a scripted submodule works and produces a call_method node."""
    class Scripted(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    class Holder(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.s = torch.jit.script(Scripted())

        def forward(self, x):
            return self.s(x)

    holder = Holder()
    traced = symbolic_trace(holder)
    sample = torch.randn(3, 4)
    self.assertEqual(traced(sample), holder(sample))

    self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
|
|
|
|
def test_namedtuple_return_trace(self):
    """Returning the ``Pair`` NamedTuple from a traced forward works."""
    class NamedTupReturn(torch.nn.Module):
        def forward(self, x):
            return Pair(x, x)

    traced = symbolic_trace(NamedTupReturn())
    sample = torch.rand(3, 4)
    self.assertEqual(traced(sample), Pair(sample, sample))
|
|
|
|
def test_named_tuple_inlined(self):
    """NamedTuple arguments built from Nodes are formatted and executed correctly."""
    class NamedTupMod(torch.nn.Module):
        def forward(self, inp):
            return wrapped_named_tup(Pair(inp, 1.2), p2=Pair(3.4, inp))

    m = NamedTupMod()
    sample = torch.rand(3, 4)
    ref = m(sample)
    traced = symbolic_trace(m)
    self.assertEqual(ref, traced(sample))

    # Check Pair NamedTuple works when inlined into the function call.
    placeholder = None
    call_func = None
    for node in traced.graph.nodes:
        if node.op == "placeholder":
            placeholder = node
            continue
        if node.op == "call_function" and node.target == wrapped_named_tup:
            node.update_arg(0, Pair(placeholder, 1.2))
            node.update_kwarg("p2", Pair(3.4, placeholder))
            call_func = node
            break

    self.assertTrue(call_func is not None)
    self.assertTrue(isinstance(call_func.args[0], Pair))
    self.assertTrue(isinstance(call_func.kwargs["p2"], Pair))
    self.assertEqual(_format_arg(call_func.args[0]), "Pair(x=%inp, y=1.2)")
    self.assertEqual(_format_arg(call_func.kwargs["p2"]), "Pair(x=3.4, y=%inp)")

    traced.graph.eliminate_dead_code()
    traced.recompile()
    self.assertEqual(ref, traced(sample))
|
|
|
|
def test_return_type_exists(self):
    """Return annotations appear in FX-generated code and in the scripted module's code."""
    class ReturnTypeModule(torch.nn.Module):
        def other(self, x: List[str]) -> List[str]:
            return x

        def forward(self, x: List[str]) -> List[str]:
            return self.other(x)

    traced = symbolic_trace(ReturnTypeModule())
    # FX codegen refers to typing.List through the ``typing_List`` alias.
    self.assertIn("-> typing_List[str]", traced._code)
    scripted = torch.jit.script(traced)
    self.assertIn("-> List[str]", scripted.code)
|
|
|
|
def getitem_inner(self):
    """Shared body for test_getitem: index a registered buffer in several ways."""
    class GetItemBase(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.pe = torch.nn.Buffer(torch.randn(8, 8))

    class GetItem1(GetItemBase):
        def forward(self, x):
            return self.pe[:, :x.size(0)]

    class GetItem2(GetItemBase):
        def forward(self, x):
            return self.pe[x.size(0)]

    class GetItem3(GetItemBase):
        def forward(self, x):
            return self.pe[4]  # fx creates `self._tensor_constant0` here

    for module_cls in (GetItem1, GetItem2, GetItem3):
        self.checkGraphModule(module_cls(), [torch.zeros(4)])
|
|
|
|
@unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
                     "Will be checked in test_getitem_subproc")
def test_getitem(self):
    """Run getitem_inner only when FX_PATCH_GETITEM=1; otherwise test_getitem_subproc covers it."""
    self.getitem_inner()
|
|
|
|
def test_getitem_subproc(self):
    """Run the getitem checks in a child process and require a clean exit code."""
    # need to run this test in a subproc to work around:
    # https://github.com/pytorch/pytorch/issues/50710
    child = Process(target=run_getitem_target)
    child.start()
    child.join()
    self.assertEqual(child.exitcode, 0)
|
|
|
|
def test_user_friendly_call_provenance_with_function(self):
    """Script errors name the traced function's forward as the caller of the failing helper."""
    def fn(x):
        return wrapper_fn(x)

    traced = torch.fx.symbolic_trace(fn)

    expected = ("'wrapper_fn' is "
                "being compiled since it was called"
                " from 'fn.forward'")
    with self.assertRaisesRegex(RuntimeError, expected):
        torch.jit.script(traced)
|
|
|
|
def test_user_friendly_call_provenance_with_module(self):
    """Script errors name the traced module's forward as the caller of the failing helper."""
    class M(torch.nn.Module):
        def forward(self, x):
            return wrapper_fn(x)

    traced = torch.fx.symbolic_trace(M())

    expected = ("'wrapper_fn' is "
                "being compiled since it was called"
                " from 'M.forward'")
    with self.assertRaisesRegex(RuntimeError, expected):
        torch.jit.script(traced)
|
|
|
|
def test_snake_case(self):
    """ModuleDict keys become snake_case node names while targets keep the original key."""
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.activations = torch.nn.ModuleDict([
                ["snake_case", torch.nn.ReLU()],
                ["PascalCase", torch.nn.LeakyReLU()],
                ["ALL_CAPS", torch.nn.PReLU()]
            ])

        def forward(self, x):
            a = self.activations["snake_case"](x)
            b = self.activations["PascalCase"](x)
            c = self.activations["ALL_CAPS"](x)
            return a, b, c

    traced = symbolic_trace(M())

    expected = [
        ("activations_snake_case", "activations.snake_case"),
        ("activations_pascal_case", "activations.PascalCase"),
        ("activations_all_caps", "activations.ALL_CAPS"),
    ]

    body_nodes = [n for n in traced.graph.nodes if n.op not in ("placeholder", "output")]
    for node, (name, target) in zip(body_nodes, expected):
        self.assertEqual(name, node.name)
        self.assertEqual(target, node.target)
    self.assertEqual(len(body_nodes), 3)
|
|
|
|
def test_no_mutation(self):
    """Item assignment on an immutable_list raises NotImplementedError mentioning ``new_args``."""
    from torch.fx.immutable_collections import immutable_list
    frozen = immutable_list([3, 4])
    with self.assertRaisesRegex(NotImplementedError, "new_args"):
        frozen[0] = 4
|
|
|
|
def test_partial_trace(self):
    """concrete_args specializes tracing on a value and guards it with ``torch._assert``."""
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            return 2 * x if y else x

    mod = Foo()
    mod_true = symbolic_trace(mod, concrete_args={'y': True})
    mod_false = symbolic_trace(mod, concrete_args={'y': False})

    self.assertEqual(mod_true(3, True), 6)
    print(mod_true.code)
    assert any(n.target == torch._assert for n in mod_true.graph.nodes)
    # Calling with a value different from the concrete one trips the guard.
    with self.assertRaises(AssertionError):
        mod_true(3, False)

    self.assertEqual(mod_false(3, False), 3)
    with self.assertRaises(AssertionError):
        mod_false(3, True)

    def f_higher(a, f):
        return f(a)

    nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
    self.assertEqual(nf(3, lambda x: x * 2), 6)
|
|
|
|
def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
    """An error raised inside generated forward code prints an FX traceback hint."""
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.W = torch.nn.Parameter(torch.randn(5))

        def forward(self, x):
            return torch.dot(self.W, x)

    traced = torch.fx.symbolic_trace(M())

    # Route the returned value through an extra relu call before the output.
    out = next(n for n in reversed(traced.graph.nodes) if n.op == "output")
    with traced.graph.inserting_before(out):
        relu_out = traced.graph.call_method(method_name='relu', args=(out.args[0],))
    out.args = (relu_out,)

    traced.recompile()

    with self.capture_stderr() as captured, self.assertRaises(TypeError):
        traced(5)

    self.assertRegex(captured[0],
                     r"Call using an FX-traced Module, line .* of the "
                     r"traced Module's generated forward function:")
|
|
|
|
def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
    """Errors raised inside a leaf submodule do not get the FX traceback hint."""
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(3, 4)

        def forward(self, x):
            return self.linear(x)

    traced = torch.fx.symbolic_trace(M())

    # Do not change this to `capture_stderr` or another context
    # manager without ensuring that the output is as expected
    try:
        # A (5, 5) input does not fit Linear(3, 4); the error comes from the submodule.
        traced(torch.rand(5, 5))
    except RuntimeError:
        captured = traceback.format_exc()
    else:
        # Without this branch, a missing error surfaced as an UnboundLocalError
        # on `captured` below instead of a clear test failure.
        self.fail("Expected the Linear submodule to raise a RuntimeError")

    self.assertNotRegex(captured,
                        r"Call using an FX-traced Module, line .* of the "
                        r"traced Module's generated forward function:")
|
|
|
|
def test_graph_module_replicate_for_dp(self):
    """A GraphModule replica made for data parallelism computes the same output."""
    class Foo(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = torch.fx.symbolic_trace(Foo())
    sample = torch.randn(5, 3)
    expected = gm(sample)

    replica = gm._replicate_for_data_parallel()
    torch.testing.assert_close(replica(sample), expected)
|
|
|
|
def test_ast_rewriter_rewrites_assert(self):
    """RewritingTracer can trace a forward containing a bare ``assert``."""
    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int, z: int):
            assert y == z
            return torch.add(x, x)

    tracer = RewritingTracer()
    graph = tracer.trace(M())
    GraphModule(tracer.root, graph, "gm").graph.lint()
|
|
|
|
def test_ast_rewriter_rewrites_assert_with_message(self):
    """RewritingTracer can trace a forward containing an ``assert`` with a message."""
    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: int, z: int):
            assert y == z, "msg"
            return torch.add(x, x)

    tracer = RewritingTracer()
    graph = tracer.trace(M())
    GraphModule(tracer.root, graph, "gm").graph.lint()
|
|
|
|
def test_throw_out_variant(self):
    """With ``check_mutable_operations``, tracing an ``out=`` call is rejected."""
    def foo(x):
        y = torch.rand_like(x)
        torch.sigmoid(x, out=y)
        return y

    class MyTracer(torch.fx.Tracer):
        check_mutable_operations = True

    strict_tracer = MyTracer()
    with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
        strict_tracer.trace(foo)
|
|
|
|
def test_ast_rewriter_reassigns_submodules(self):
    """RewritingTracer handles a module that owns a submodule its forward never uses."""
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(100)

        def forward(self, x: torch.Tensor):
            return torch.add(x, x)

    tracer = RewritingTracer()
    graph = tracer.trace(M())
    GraphModule(tracer.root, graph, "gm").graph.lint()
|
|
|
|
def test_ast_rewriter_wrap(self):
    """RewritingTracer keeps ``a_lifted_leaf`` as a call and restores the global afterwards."""
    self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))

    def to_trace(y):
        return (
            a_lifted_leaf((4, y), 3)
            + a_lifted_leaf((3, 4), 5)
            + a_lifted_leaf((y, y), y)
        )

    tracer = RewritingTracer()
    graph = tracer.trace(to_trace)
    traced = GraphModule(tracer.root, graph, "gm")

    self.assertIn("a_lifted_leaf", traced.code)
    self.assertEqual(27, traced(2))
    self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
|
|
|
|
def test_ast_rewriter_wrap_fn_directly(self):
    """RewritingTracer keeps ``a_lifted_leaf2`` as a call and restores the global afterwards."""
    self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))

    def to_trace(y):
        return (
            a_lifted_leaf2((4, y), 3)
            + a_lifted_leaf2((3, 4), 5)
            + a_lifted_leaf2((y, y), y)
        )

    tracer = RewritingTracer()
    graph = tracer.trace(to_trace)
    traced = GraphModule(tracer.root, graph, "gm")

    self.assertIn("a_lifted_leaf2", traced.code)
    self.assertEqual(27, traced(2))
    self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
|
|
|
|
def test_profiler_ranges_side_effect(self):
    """Profiler enter/exit calls stay in the graph across eliminate_dead_code.

    The exit call's result is unused, yet both calls must remain.
    """
    g = torch.fx.Graph()
    handle = g.call_function(torch.ops.profiler._record_function_enter_new, ('test_range',))
    g.call_function(torch.ops.profiler._record_function_exit, (handle,))
    g.output(None)

    expected = [torch.ops.profiler._record_function_enter_new, torch.ops.profiler._record_function_exit]

    def unique_call_targets():
        # Distinct call_function targets, in first-seen graph order.
        return list(dict.fromkeys(n.target for n in g.nodes if n.op == 'call_function'))

    self.assertEqual(unique_call_targets(), expected)

    g.eliminate_dead_code()
    self.assertEqual(unique_call_targets(), expected)
|
|
|
|
def test_ast_rewriter_wrapped_via_decorator(self):
    """A function wrapped via decorator stays a leaf when traced from a module."""
    class F(torch.nn.Module):
        def forward(self, x):
            return wrapped_via_decorator(x)

    rewriter = RewritingTracer()
    # trace() populates rewriter.root, so it must run before root is read.
    traced_graph = rewriter.trace(F())
    gm = GraphModule(rewriter.root, traced_graph, "gm")

    self.assertIn("wrapped_via_decorator", gm.code)
    self.assertEqual(gm(0), 1)
    # The global binding must be the original object, with no patch marker left on it.
    self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
    self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
|
|
|
|
def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
    """A decorator-wrapped leaf survives both RewritingTracer and a Transformer pass."""
    self.assertEqual(wrapped_via_decorator(0), 1)

    def to_trace(y):
        return wrapped_via_decorator(y)

    def check(module):
        # The leaf call is still present and correct, and the global is unpatched.
        self.assertIn("wrapped_via_decorator", module.code)
        self.assertEqual(module(0), 1)
        self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
        self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))

    rewriter = RewritingTracer()
    # trace() populates rewriter.root, so it must run before root is read.
    traced_graph = rewriter.trace(to_trace)
    traced = GraphModule(rewriter.root, traced_graph, "gm")
    check(traced)

    check(torch.fx.Transformer(traced).transform())
|
|
|
|
def test_ast_rewriter_wrap_with_submodule(self):
    """A wrapped function that receives a submodule is kept as a leaf and computes correctly."""
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)

        def forward(self, x: torch.Tensor):
            return wrapped_with_submodule(x, self.batchnorm1d)

    rewriter = RewritingTracer()
    # trace() populates rewriter.root, so it must run before root is read.
    traced_graph = rewriter.trace(M())
    gm = GraphModule(rewriter.root, traced_graph, "gm")

    self.assertIn("wrapped_with_submodule", gm.code)

    inp = torch.rand(3, 2)
    reference = torch.nn.BatchNorm1d(2, affine=False)
    self.assertEqual(reference(inp), gm(inp))
|
|
|
|
def test_submodule_manipulation_API(self):
|
|
class C(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
|
|
self.param = torch.nn.Parameter(torch.rand(2, 3))
|
|
|
|
def forward(self, x):
|
|
return self.conv(torch.cat([self.param, x]))
|
|
|
|
class B(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(100, 200)
|
|
self.buf = torch.nn.Buffer(torch.randn(2, 3))
|
|
self.net_c = C()
|
|
|
|
def forward(self, x):
|
|
return self.linear(torch.cat([self.buf, self.net_c(x)]))
|
|
|
|
class A(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.net_b = B()
|
|
self.param = torch.nn.Parameter(torch.rand(2, 3))
|
|
|
|
def forward(self, x):
|
|
return self.net_b(x) + self.param
|
|
|
|
a = symbolic_trace(A())
|
|
|
|
a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))
|
|
|
|
conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
|
|
with a.graph.inserting_before(conv):
|
|
with warnings.catch_warnings(record=True) as w:
|
|
dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
|
|
args=conv.args)
|
|
self.assertEqual(len(w), 0)
|
|
|
|
conv.replace_all_uses_with(dropout)
|
|
a.graph.erase_node(conv)
|
|
a.recompile()
|
|
|
|
def module_exists(gm: GraphModule, path: str) -> bool:
|
|
return any(path == name for name, _ in gm.named_modules())
|
|
|
|
def parameter_exists(gm: GraphModule, path: str) -> bool:
|
|
return (any(path == name for name, _ in gm.named_parameters())
|
|
and any(path == name for name in gm.state_dict().keys()))
|
|
|
|
def buffer_exists(gm: GraphModule, path: str) -> bool:
|
|
return (any(path == name for name, _ in gm.named_buffers())
|
|
and any(path == name for name in gm.state_dict().keys()))
|
|
|
|
# Test that we added the "dropout" submodule
|
|
self.assertTrue(module_exists(a, "net_b.net_c.dropout"))
|
|
|
|
# Test `get_submodule` with an added submodule
|
|
self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))
|
|
|
|
# Test that the "conv" submodule is still there
|
|
self.assertTrue(module_exists(a, "net_b.net_c.conv"))
|
|
|
|
# Test `get_submodule` with an original module
|
|
self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))
|
|
|
|
# Test that the "conv" node is NOT still there
|
|
conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
|
|
self.assertEqual(conv, [])
|
|
|
|
a.delete_submodule("net_b.net_c.conv")
|
|
|
|
# Test that the "conv" submodule is now gone
|
|
self.assertFalse(module_exists(a, "net_b.net_c.conv"))
|
|
|
|
# Test `get_submodule` with a deleted submodule
|
|
with self.assertRaisesRegex(AttributeError, "has no attribute "
|
|
"`conv`"):
|
|
self.assertIsNone(a.get_submodule("net_b.net_c.conv"))
|
|
|
|
# Test `get_attr` warnings
|
|
cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]
|
|
|
|
with a.graph.inserting_before(cat):
|
|
|
|
with warnings.catch_warnings(record=True) as w:
|
|
param = a.graph.get_attr(qualified_name="net_b.net_c.param")
|
|
self.assertEqual(len(w), 0)
|
|
|
|
with self.assertWarnsRegex(UserWarning, "Attempted to "
|
|
"insert a get_attr Node with no "
|
|
"underlying reference in the "
|
|
"owning GraphModule"):
|
|
bad_param = a.graph.get_attr(qualified_name="net_b.param")
|
|
a.graph.erase_node(bad_param)
|
|
|
|
cat.args = (*cat.args, param)
|
|
|
|
a.recompile()
|
|
|
|
a.graph.lint()
|
|
|
|
# Test `get_parameter`
|
|
a.get_parameter("net_b.net_c.param")
|
|
with self.assertRaisesRegex(AttributeError, "is not an "
|
|
"nn.Parameter"):
|
|
a.get_parameter("net_b.buf")
|
|
with self.assertRaisesRegex(AttributeError, "has no attribute "
|
|
"`param`"):
|
|
a.get_parameter("net_b.param")
|
|
|
|
# Test `get_buffer`
|
|
a.get_buffer("net_b.buf")
|
|
with self.assertRaisesRegex(AttributeError, "is not a "
|
|
"buffer"):
|
|
a.get_buffer("net_b.net_c.param")
|
|
with self.assertRaisesRegex(AttributeError, "has no attribute "
|
|
"`buf`"):
|
|
a.get_buffer("net_b.net_c.buf")
|
|
|
|
# Test non-nested attributes
|
|
a.get_submodule("")
|
|
a.get_parameter("param")
|
|
|
|
# Insert some unused submodules
|
|
a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
|
|
a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
|
|
a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
|
|
a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))
|
|
|
|
# Garbage collection
|
|
a.delete_all_unused_submodules()
|
|
|
|
# Test that all the unused submodules are gone
|
|
self.assertFalse(module_exists(a, "net_b.embedding"))
|
|
self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
|
|
self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
|
|
self.assertFalse(module_exists(a, "batch_norm_2d"))
|
|
|
|
# Test that we didn't delete any unused Parameters or buffers
|
|
self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
|
|
self.assertTrue(buffer_exists(a, "net_b.buf"))
|
|
|
|
a.graph.lint()
|
|
|
|
def test_delete_unused_submodules_leaf(self):
|
|
class SubModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(10, 10)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear(x)
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.submod = SubModule()
|
|
|
|
def forward(self, x):
|
|
x = self.submod(x)
|
|
return x
|
|
|
|
model = Model()
|
|
|
|
class MyCustomTracer(torch.fx.Tracer):
|
|
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
|
|
return module_qualified_name == "submod"
|
|
|
|
inputs = torch.randn(1, 10)
|
|
traced_graph = MyCustomTracer().trace(model)
|
|
gm2 = torch.fx.GraphModule(model, traced_graph)
|
|
gm2.delete_all_unused_submodules()
|
|
torch.testing.assert_close(gm2(inputs), model(inputs))
|
|
|
|
def test_fx_stateless(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.l1 = torch.nn.Linear(1, 1)
|
|
self.buffer = torch.nn.Buffer(torch.ones(1))
|
|
|
|
def forward(self, x):
|
|
return self.l1(x) + self.buffer
|
|
|
|
module = MockModule()
|
|
x = torch.rand((1, 1))
|
|
weight = torch.tensor([[1.0]], requires_grad=True)
|
|
bias = torch.tensor([0.0], requires_grad=True)
|
|
buffer = torch.tensor([0.0])
|
|
parameters = {'l1.weight': weight,
|
|
'l1.bias': bias,
|
|
'buffer': buffer}
|
|
fx_module = torch.fx.symbolic_trace(module)
|
|
res = torch.func.functional_call(fx_module, parameters, x)
|
|
res.backward()
|
|
self.assertIsNotNone(weight.grad)
|
|
self.assertIsNotNone(bias.grad)
|
|
self.assertIsNone(buffer.grad)
|
|
# Gradient was not calculated for the module stated and buffers
|
|
self.assertIsNone(module.l1.weight.grad)
|
|
self.assertIsNone(module.l1.bias.grad)
|
|
self.assertIsNone(module.buffer.grad)
|
|
|
|
def test_tracing_graphmodules_as_leaf_submodules(self):
|
|
class A(torch.nn.Module):
|
|
def forward(self, t):
|
|
return t + t
|
|
|
|
class B(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super(type(self), self).__init__()
|
|
self.calling = False
|
|
self.called = False
|
|
|
|
def forward(self, t):
|
|
if self.calling:
|
|
return t - t
|
|
else:
|
|
return t + t
|
|
|
|
def __call__(self, *args):
|
|
self.called = True
|
|
self.calling = True
|
|
return super(type(self), self).__call__(*args)
|
|
self.calling = False
|
|
|
|
class M(torch.nn.Module):
|
|
def __init__(self, a, b):
|
|
super().__init__()
|
|
self.a = a
|
|
self.b = b
|
|
|
|
def forward(self, t):
|
|
x = self.a(t)
|
|
y = self.b(t)
|
|
return x + y
|
|
|
|
class LeafTracer(Tracer):
|
|
def is_leaf_module(self, module, name):
|
|
return True
|
|
|
|
class LeafTracerNotB(Tracer):
|
|
def is_leaf_module(self, module, name):
|
|
return False if "b" in name else True
|
|
|
|
# Recompile calls added "for fun", since they
|
|
# chain __call__ wrappers.
|
|
|
|
#
|
|
# Test: B as a regular, non-leaf module
|
|
#
|
|
a = symbolic_trace(A())
|
|
a.recompile()
|
|
m = M(a, B())
|
|
graph = LeafTracerNotB().trace(m)
|
|
gm = GraphModule(m, graph)
|
|
gm.recompile()
|
|
|
|
# Test graphmodule/submodule a is not inlined.
|
|
self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
|
|
match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
# Test submodule b is not treated as leaf.
|
|
self.assertFalse(hasattr(gm, "b"))
|
|
|
|
# Test assert custom __call__ on submodule b was honored.
|
|
match = [
|
|
n
|
|
for n in gm.graph.nodes
|
|
if n.op == "call_function" and n.target == operator.sub
|
|
]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
#
|
|
# Test: B as a regular, leaf module
|
|
# symbolic_trace should only patch torch.nn.Module.__call__,
|
|
# which means B.__call__ should still execute
|
|
#
|
|
a = symbolic_trace(A())
|
|
a.recompile()
|
|
b = B()
|
|
m = M(a, b)
|
|
graph = LeafTracer().trace(m)
|
|
gm = GraphModule(m, graph)
|
|
gm.recompile()
|
|
|
|
# Test graphmodule/submodule a is not inlined.
|
|
self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
|
|
match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
# Test submodule b is leaf:
|
|
self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
|
|
match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
# Test b.__call__ was run
|
|
self.assertTrue(b.called)
|
|
self.assertTrue(gm.get_submodule("b").called)
|
|
|
|
#
|
|
# Test: B as GraphModule leaf
|
|
# __call__ not honored since symbolic_trace directly invokes forward()
|
|
#
|
|
a = symbolic_trace(A())
|
|
a.recompile()
|
|
b = symbolic_trace(B())
|
|
b.recompile()
|
|
m = M(a, b)
|
|
graph = LeafTracer().trace(m)
|
|
gm = GraphModule(m, graph)
|
|
gm.recompile()
|
|
|
|
self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
|
|
match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
|
|
match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
|
|
self.assertTrue(len(match) == 1)
|
|
|
|
def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.my_buff = torch.nn.Buffer(torch.rand(3, 4))
|
|
self.register_parameter(
|
|
"my_param", torch.nn.Parameter(torch.rand(3, 4))
|
|
)
|
|
|
|
def forward(self, x):
|
|
return x + self.my_buff + self.my_param
|
|
|
|
mod = MyModule()
|
|
mod_traced = symbolic_trace(mod)
|
|
|
|
# Create new GraphModule based on original, either w/ dict or root module.
|
|
orig_buff = mod_traced.get_buffer("my_buff")
|
|
orig_param = mod_traced.get_parameter("my_param")
|
|
mod_traced_new = GraphModule(
|
|
{"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
|
|
mod_traced.graph,
|
|
)
|
|
|
|
# Check that both my_buff and my_param are found and the same.
|
|
try:
|
|
new_buff = mod_traced_new.get_buffer("my_buff")
|
|
except Exception:
|
|
self.fail("Did not find my_buff")
|
|
self.assertEqual(orig_buff, new_buff)
|
|
|
|
try:
|
|
new_param = mod_traced_new.get_parameter("my_param")
|
|
except Exception:
|
|
self.fail("Did not find my_param")
|
|
self.assertEqual(orig_param, new_param)
|
|
|
|
x = torch.rand(3, 4)
|
|
orig_out = mod_traced(x)
|
|
submodules_out = mod_traced_new(x)
|
|
|
|
self.assertEqual(orig_out, submodules_out)
|
|
|
|
def test_graph_module_init_buffer_param_copied_dict_init(self):
|
|
self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
|
|
|
|
def test_graph_module_init_buffer_param_copied_mod_init(self):
|
|
self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
|
|
|
|
def test_annotations_with_no_forward_references(self):
|
|
class A:
|
|
def __call__(self, x: torch.Tensor):
|
|
return torch.add(x, x)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
|
|
return a(x)
|
|
|
|
self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
|
|
|
|
def test_annotations_with_forward_references(self):
|
|
class A:
|
|
def __call__(self, x: torch.Tensor):
|
|
return torch.add(x, x)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
|
|
return a(x)
|
|
|
|
self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
|
|
|
|
def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
|
|
class A:
|
|
def __call__(self, x: torch.Tensor):
|
|
return torch.add(x, x)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
|
|
return a(x[0])
|
|
|
|
self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
|
|
|
|
def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
|
|
class A:
|
|
def __call__(self, x: torch.Tensor):
|
|
return torch.add(x, x)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
|
|
return a(x)[0]
|
|
|
|
self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
|
|
|
|
@unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
                 "`annotations` is not defined in Python <3.7")
def test_annotation_with_future(self):
    """Import the ``fx.test_future`` test module.

    Presumably that module exercises ``from __future__ import annotations``,
    per the skip message; its contents are not visible here.
    """
    try:
        import fx.test_future  # noqa: F401
    finally:
        # Always evict the cached "__future__" entry from sys.modules afterwards.
        sys.modules.pop("__future__") if False else sys.modules.__delitem__("__future__")
|
|
|
|
@unittest.skipIf(sys.version_info > (3, 11), "Does not work in 3.11")
def test_annotations_empty_tuple(self):
    """Empty-tuple annotations survive FX codegen and TorchScript compilation."""
    class Foo(torch.nn.Module):
        def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
            return "foo"

    traced = torch.fx.symbolic_trace(Foo())
    x, y = (), ("bar", ())

    traced(x, y)
    # FX emits typing aliases with an underscore prefix.
    fx_check = FileCheck().check("_Tuple[()]")
    fx_check = fx_check.check("typing_Tuple[str,typing_Tuple[()]]")
    fx_check.run(traced.code)

    scripted = torch.jit.script(traced)
    scripted(x, y)
    ts_check = FileCheck().check("Tuple[()]")
    ts_check = ts_check.check("Tuple[str, Tuple[()]]")
    ts_check.run(scripted.code)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Python Windows bug? https://bugs.python.org/issue45108")
@unittest.skipIf(sys.version_info >= (3, 10), "Does not work on Python-3.10")
def test_assert(self):
    """With ``trace_asserts`` enabled, ``assert`` statements are kept in the traced graph."""
    def f(x):
        assert x > 1
        return x + 1

    # Restore whatever value the flag had before, rather than forcing it to
    # False, so this test does not change global state for other tests.
    orig_trace_asserts = torch.fx.proxy.TracerBase.trace_asserts
    try:
        torch.fx.proxy.TracerBase.trace_asserts = True
        traced = symbolic_trace(f)
    finally:
        torch.fx.proxy.TracerBase.trace_asserts = orig_trace_asserts

    self.assertEqual(f(2), traced(2))
    with self.assertRaises(AssertionError):
        traced(0)
|
|
|
|
def test_pytree(self):
    """Trace functions whose single input is a pytree passed through ``concrete_args``.

    For each (function, input structure) pair the test checks:

    * the traced module matches eager execution;
    * the flattening codegen round-trips through a plain ``CodeGen``;
    * retracing collapses the input to one placeholder, then flattens it again;
    * the result survives pickling.
    """
    # Used to test that you can use your own placeholder class
    class PHTest(PHBase):
        pass

    def f_sum(x):
        return sum(x)

    def f_sum_dict(x):
        out = 0
        for v in x.values():
            out += v
        return out

    def f_dict_list_map(x):
        return {k: [i + 1 for i in v] for k, v in x.items()}

    def f_dict_add(x):
        return x['a'] + sum(x['z'])

    def f_namedtuple_add(x):
        return x.x + x.y

    # Teach both pytree implementations how to flatten the custom `Foo` type.
    pytree.register_pytree_node(
        Foo,
        lambda x: ([x.a, x.b], None),
        lambda x, _: Foo(x[0], x[1]),
    )
    fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])

    def f_custom(x):
        return x.a + x.b

    def f_custom_dict(x):
        return f_sum_dict(x.a) + x.b

    def f_return_custom(x):
        return Foo(x.b, x.a)

    tests = [
        (f_sum, [PH, PH, PH]),
        (f_sum, []),
        (f_sum, [PHTest(), PHTest(), PHTest()]),
        (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
        (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
        (f_dict_list_map, {5: (PH, PH, PH)}),
        (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
        (f_dict_add, {'a': PH, 'z': []}),
        (f_custom, Foo(PH, PH)),
        (f_custom, Foo(PH, 3)),
        (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
        # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
        (f_namedtuple_add, Point(PH, PH)),
    ]

    def count_placeholders(gm):
        return sum(node.op == 'placeholder' for node in gm.graph.nodes)

    def verify_pytree(f, inp):
        # Swap every placeholder leaf for a fresh random tensor.
        val = pytree.tree_map(lambda x: torch.randn(3) if isinstance(x, PHBase) else x, inp)
        num_flat_args = len(pytree.tree_leaves(inp))
        orig_out = f(val)

        # Trace with the pytree structure supplied as concrete_args.
        nf = symbolic_trace(f, concrete_args={'x': inp})
        self.assertEqual(nf(val), orig_out)

        # A copy with the default CodeGen takes flat inputs; process_inputs /
        # process_outputs translate to and from the pytree form.
        bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
        bare_fx.graph.set_codegen(CodeGen())
        bare_fx.recompile()
        flat_inputs = nf.graph.process_inputs(val)
        self.assertEqual(nf.graph.process_outputs(bare_fx(*flat_inputs)), orig_out)

        assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
        assert count_placeholders(nf) == num_flat_args

        # Retracing without concrete_args leaves a single placeholder and no flattening.
        nf = symbolic_trace(nf)
        self.assertEqual(nf(val), orig_out)
        assert "tree_flatten_spec" not in nf.code
        assert count_placeholders(nf) == 1

        # Retracing with concrete_args flattens the input again.
        nf = symbolic_trace(nf, concrete_args={'x': inp})
        self.assertEqual(nf(val), orig_out)
        assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
        assert count_placeholders(nf) == num_flat_args

        # The traced module survives a pickle round trip (self-produced data only).
        nf = pickle.loads(pickle.dumps(nf))
        self.assertEqual(nf(val), orig_out)

    for f, inp in tests:
        verify_pytree(f, inp)
|
|
|
|
def test_pytree_concrete(self):
|
|
def f(b, a):
|
|
if b:
|
|
return a['a']
|
|
else:
|
|
return a['z']
|
|
|
|
inp = {'a': {'a': PH, 'z': PH}, 'b': True}
|
|
nf = symbolic_trace(f, concrete_args=inp)
|
|
val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
|
|
self.assertEqual(nf(**val), f(**val))
|
|
|
|
nf = symbolic_trace(nf)
|
|
self.assertEqual(nf(**val), f(**val))
|
|
|
|
def test_metadata_on_ph(self):
|
|
def f_sum(a: int, b: int) -> int:
|
|
return a + b
|
|
|
|
# Due to unflattening of dict, the batch argument
|
|
# will be split into two separate nodes with the names
|
|
# "batch_1" and "batch_2", referring to the keys
|
|
# "f1" and "f2" respectively in the dict.
|
|
def f_dict(a: Dict[str, str]) -> bool:
|
|
return a["f1"] == a["f2"]
|
|
|
|
def verify_metadata(gm: GraphModule, arg_names: List[str], metadata: List[str]):
|
|
for node in gm.graph.nodes:
|
|
if node.op == "placeholder":
|
|
self.assertTrue(node.name in arg_names)
|
|
self.assertTrue(node.ph_key in metadata)
|
|
|
|
verify_metadata(
|
|
gm=symbolic_trace(
|
|
f_sum,
|
|
concrete_args={"a": PHWithMeta(ph_key="a"), "b": PHWithMeta(ph_key="b")}
|
|
),
|
|
arg_names=["a_1", "b_1"],
|
|
metadata=["a", "b"]
|
|
)
|
|
verify_metadata(
|
|
gm=symbolic_trace(
|
|
f_dict,
|
|
concrete_args={"a": {"f1": PHWithMeta(ph_key="f1"), "f2": PHWithMeta(ph_key="f2")}}
|
|
),
|
|
arg_names=["a_1", "a_2"],
|
|
metadata=["f1", "f2"]
|
|
)
|
|
|
|
# Ensures that tags on nodes are NOT overwritten by PH attributes with same attr name (tag)
|
|
class TaggingTracer(Tracer):
|
|
def create_node(self, kind : str, target : Union[str, Callable],
|
|
args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
|
|
type_expr : Optional[Any] = None) -> Node:
|
|
n = super().create_node(kind, target, args, kwargs, name)
|
|
n.tag = "foo"
|
|
return n
|
|
|
|
class PHWithTag(PHBase):
|
|
def __init__(self, tag: str):
|
|
super().__init__()
|
|
|
|
self.tag = tag
|
|
|
|
g = TaggingTracer().trace(f_sum, concrete_args={"a": PHWithTag(tag="bar"), "b": PHWithTag(tag="bar")})
|
|
for n in g.nodes:
|
|
self.assertTrue(hasattr(n, "tag"))
|
|
# Ensure that tag is still "foo" and not "bar" (from PHWithTag)
|
|
self.assertEqual(n.tag, "foo")
|
|
|
|
def test_custom_codegen(self):
|
|
class ListCodeGen(CodeGen):
|
|
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
|
lst_unpack = f"""
|
|
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
|
{', '.join(free_vars)} = args_list"""
|
|
return lst_unpack
|
|
|
|
def additional_globals(self):
|
|
return [('List', typing.List)]
|
|
|
|
def process_inputs(self, *inputs):
|
|
assert len(inputs) == 1
|
|
return inputs[0]
|
|
|
|
def f(a, b):
|
|
return a + b
|
|
|
|
nf = symbolic_trace(f)
|
|
vals = [torch.randn(3), torch.randn(3)]
|
|
self.assertEqual(nf(*vals), f(*vals))
|
|
|
|
nf.graph.set_codegen(ListCodeGen())
|
|
nf.recompile()
|
|
|
|
bare_fx = GraphModule({}, copy.deepcopy(nf.graph))
|
|
bare_fx.graph.set_codegen(CodeGen())
|
|
bare_fx.recompile()
|
|
|
|
self.assertEqual(nf(vals), f(*vals))
|
|
self.assertEqual(nf.graph.process_outputs(bare_fx(*nf.graph.process_inputs(vals))), f(*vals))
|
|
|
|
ts_f = torch.jit.script(nf)
|
|
self.assertEqual(nf(vals), ts_f(vals))
|
|
|
|
def test_custom_codegen_with_transformer(self):
|
|
class ListCodeGen(CodeGen):
|
|
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
|
lst_unpack = f"""
|
|
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
|
{', '.join(free_vars)} = args_list"""
|
|
return lst_unpack
|
|
|
|
def additional_globals(self):
|
|
return [('List', typing.List)]
|
|
|
|
def process_inputs(self, *inputs):
|
|
assert len(inputs) == 1
|
|
return inputs[0]
|
|
|
|
def f(a, b):
|
|
return a + b
|
|
|
|
nf = symbolic_trace(f)
|
|
vals = [torch.randn(3), torch.randn(3)]
|
|
self.assertEqual(nf(*vals), f(*vals))
|
|
|
|
nf.graph.set_codegen(ListCodeGen())
|
|
nf.recompile()
|
|
self.assertEqual(nf(vals), f(*vals))
|
|
|
|
transformed_gm = Transformer(nf).transform()
|
|
self.assertEqual(nf(vals), transformed_gm(vals))
|
|
|
|
def test_interpreter_with_codegen(self):
|
|
class ListCodeGen(CodeGen):
|
|
def gen_fn_def(self, free_vars, maybe_return_annotation):
|
|
lst_unpack = f"""
|
|
def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
|
|
{', '.join(free_vars)} = args_list"""
|
|
return lst_unpack
|
|
|
|
def additional_globals(self):
|
|
return [('List', typing.List)]
|
|
|
|
def process_inputs(self, *inputs):
|
|
assert len(inputs) == 1
|
|
return inputs[0]
|
|
|
|
def generate_output(self, output_args):
|
|
return f'return list({repr(output_args)})'
|
|
|
|
def process_outputs(self, outputs):
|
|
return list(outputs)
|
|
|
|
def f(a, b):
|
|
a = a + b
|
|
b = a + b
|
|
return a, b
|
|
|
|
nf = symbolic_trace(f)
|
|
vals = [torch.randn(3), torch.randn(3)]
|
|
nf.graph.set_codegen(ListCodeGen())
|
|
nf.recompile()
|
|
self.assertEqual(Interpreter(nf).run(vals), nf(vals))
|
|
|
|
def test_imul_code_print(self):
|
|
graph = torch.fx.Graph()
|
|
a = graph.placeholder("a")
|
|
b = graph.placeholder("b")
|
|
graph.call_function(operator.imul, (a, b), {})
|
|
graph.output(a)
|
|
gm = torch.fx.GraphModule({}, graph)
|
|
gm.recompile()
|
|
self.assertEqual(gm(2, 3), 6)
|
|
self.assertIn("a *= b", gm.code)
|
|
|
|
def test_deepcopy_tracer(self):
|
|
def fn(x, y):
|
|
return (x + y).relu().sin()
|
|
|
|
tracer = Tracer()
|
|
tracer_before = copy.deepcopy(tracer)
|
|
tracer.trace(fn)
|
|
tracer_after = copy.deepcopy(tracer)
|
|
|
|
self.assertEqual(str(tracer.graph), str(tracer_after.graph))
|
|
self.assertTrue(not hasattr(tracer_before, 'graph') or str(tracer.graph) != str(tracer_before.graph))
|
|
|
|
def test_deepcopy_graphmodule(self):
    """Entries in GraphModule.meta are carried over by copy.deepcopy."""
    m = symbolic_trace(SimpleTest())
    m.meta['hello'] = 'world'
    cloned = copy.deepcopy(m)
    self.assertEqual(cloned.meta['hello'], 'world')
|
|
|
|
def test_deepcopy_no_recursion(self):
    """A GraphModule whose meta refers back to itself can be deep-copied.

    The copy's self-reference must point at the copy itself.
    """
    m = symbolic_trace(SimpleTest())
    m.meta['hello'] = m  # circular reference
    copy_m = copy.deepcopy(m)  # must terminate
    self.assertIs(copy_m.meta['hello'], copy_m)
|
|
|
|
def test_enum(self):
|
|
from enum import Enum
|
|
|
|
class Foo(Enum):
|
|
A = 1
|
|
B = 2
|
|
|
|
def leaf_fn(arr, enum_val):
|
|
# Use the raw enum.
|
|
arr.append(enum_val)
|
|
return arr[-1].value
|
|
|
|
def foo(x):
|
|
# Pass the enum as argument.
|
|
return leaf_fn(x, Foo.A)
|
|
|
|
traced = torch.fx.symbolic_trace(foo)
|
|
self.assertEqual(foo([]), traced([]))
|
|
|
|
def test_insert_arg(self):
    """Node.insert_arg places the argument and records the user edge only once."""
    m = symbolic_trace(SimpleTest())
    m.buf = torch.nn.Buffer(torch.tensor(0))
    output_node = next(iter(reversed(m.graph.nodes)))
    with m.graph.inserting_before(output_node):
        a = m.graph.get_attr("buf")
    r = len(output_node.args)

    def check(position, added):
        self.assertEqual(len(output_node.args), r + added)
        # Inserting the same node twice must not create a second users entry.
        self.assertEqual(len(a.users), 1)
        self.assertIs(output_node.args[position], a)
        self.assertIs(next(iter(a.users.keys())), output_node)

    output_node.insert_arg(0, a)
    check(0, 1)
    output_node.insert_arg(2, a)
    check(2, 2)
    m.graph.lint()
|
|
|
|
def test_delete_unused_values(self):
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
# disable mutable checking temporarily
|
|
orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
|
|
torch.fx.proxy.TracerBase.check_mutable_operations = False
|
|
|
|
def fn(a, b, c, d):
|
|
x = a + b
|
|
y = c + d
|
|
y.copy_(x)
|
|
x = torch.relu(x)
|
|
return x
|
|
|
|
a, b, c, d = (torch.randn(2, 4, requires_grad=False) for _ in range(4))
|
|
fx_fn = make_fx(fn)(a, b, c, d)
|
|
print(fx_fn)
|
|
|
|
fx_fn.graph.eliminate_dead_code()
|
|
py_code = fx_fn.recompile()
|
|
self.assertTrue("copy_ = torch.ops.aten.copy_.default" in py_code.src)
|
|
self.assertTrue("copy_ = None" in py_code.src)
|
|
|
|
# recorver mutable checking flag
|
|
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
|
|
|
|
def run_getitem_target():
    """Run ``TestFX().getitem_inner()`` with ``Tensor.__getitem__`` temporarily auto-wrapped."""
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch

    patch_entry = (torch.Tensor, "__getitem__")
    _wrapped_methods_to_patch.append(patch_entry)
    try:
        TestFX().getitem_inner()
    finally:
        # Remove the entry appended above, even if the test body raised.
        _wrapped_methods_to_patch.pop()
|
|
|
|
|
|
class TestOperatorSignatures(JitTestCase):
    """Check that FX's schema lookup can bind and run OpInfo sample inputs for builtin ops."""

    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged.
        # Enable it in testing but not by default; tearDown restores the old value.
        tracer_base = torch.fx.proxy.TracerBase
        self.orig_tracer_mutable_flag = tracer_base.check_mutable_operations
        tracer_base.check_mutable_operations = True

    def tearDown(self):
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        """Every sample input must bind to, and run under, at least one schema of ``op``."""
        if not isinstance(op.op, types.BuiltinFunctionType):
            raise unittest.SkipTest("This path doesn't work on Python functions")
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        schemas = get_signature_for_torch_op(op.op)
        if not schemas:
            raise RuntimeError('No Schemas Returned')
        for sample_input in sample_inputs_itr:
            # Try each overload in turn; a TypeError means "not this one".
            matched = False
            for schema in schemas:
                try:
                    bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                    bound_args.apply_defaults()
                    op(*bound_args.args, **bound_args.kwargs)
                except TypeError:
                    continue
                matched = True
                break
            if not matched:
                raise RuntimeError(f'Did not match any schemas for op {op.name}!')
|
|
|
|
|
|
class TestFXAPIBackwardCompatibility(JitTestCase):
|
|
def setUp(self):
    super().setUp()
    # Report full, untruncated diffs when an equality assertion fails.
    self.maxDiff = None

    # Mutable-operation checking during tracing is feature flagged: remember
    # the current setting and enable it for these tests only.
    tracer_base = torch.fx.proxy.TracerBase
    self.orig_tracer_mutable_flag = tracer_base.check_mutable_operations
    tracer_base.check_mutable_operations = True
|
|
|
|
def tearDown(self):
    super().tearDown()
    # Put back the mutable-checking flag value captured in setUp.
    saved_flag = self.orig_tracer_mutable_flag
    torch.fx.proxy.TracerBase.check_mutable_operations = saved_flag
|
|
|
|
|
|
def _fn_to_stable_annotation_str(self, obj):
    """
    Render ``obj``'s signature as a string that is stable across Python versions.

    Serialization of ``inspect.Signature`` objects is not stable across Python
    versions, so the signature is rebuilt by hand from its parameters, their
    annotations and defaults, and the return annotation.
    """
    fn_name = torch.typename(obj)
    signature = inspect.signature(obj)
    # Passed through to the annotation formatter (its use is not visible here).
    sig_str = f'{fn_name}{signature}'

    def render_default(val):
        # Sequences: render the elements recursively inside matching brackets.
        # Note: every tuple, including the empty one, gets a trailing comma,
        # e.g. "(1, 2,)" and "(,)".
        if isinstance(val, (tuple, list)):
            inner = ', '.join(render_default(item) for item in val)
            if isinstance(val, tuple):
                return '(' + inner + ',)'
            return '[' + inner + ']'

        # Modules: the default repr contains the module's filesystem path; don't leak it.
        if isinstance(val, types.ModuleType):
            return f'<module {val.__name__}>'

        # Callables (e.g. lambdas) encode their address in their repr; keep only the name.
        if callable(val):
            return f'<function {val.__name__}>'

        return str(val)

    star_prefixes = {
        inspect.Parameter.VAR_POSITIONAL: '*',
        inspect.Parameter.VAR_KEYWORD: '**',
    }

    arg_strs = []
    for param_name, param in signature.parameters.items():
        if param.annotation is inspect.Signature.empty:
            annotation = ''
        else:
            annotation = f': {self._annotation_type_to_stable_str(param.annotation, sig_str)}'

        if param.default is inspect.Signature.empty:
            default = ''
        elif isinstance(param.default, str):
            default = f" = '{param.default}'"
        else:
            default = f' = {render_default(param.default)}'

        stars = star_prefixes.get(param.kind, '')
        arg_strs.append(f'{stars}{param_name}{annotation}{default}')

    if signature.return_annotation is inspect.Signature.empty:
        return_annot = ''
    else:
        return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'

    return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
|
|
|
|
def _annotation_type_to_stable_str(self, t, sig_str):
    """
    Render the annotation ``t`` as a string that is stable across Python versions.

    The ``repr`` of ``typing`` constructs differs between interpreter releases,
    so the text is assembled by hand from a fixed table of known types plus the
    structure of generic aliases (``__origin__`` / ``__args__``).

    Args:
        t: The annotation to render: ``inspect.Signature.empty``, a string or
            ``ForwardRef`` forward reference, a class, or a ``typing`` /
            PEP 604 construct.
        sig_str: Text of the full signature being rendered; only used to make
            the error message point at the offending API.

    Returns:
        The stable rendering of ``t`` ('' for ``inspect.Signature.empty``).

    Raises:
        RuntimeError: If ``t`` is not a type this function knows how to render.
    """
    if t is inspect.Signature.empty:
        return ''

    # Forward ref: plain strings are rendered quoted, ForwardRef objects bare.
    if isinstance(t, str):
        return f"'{t}'"
    if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
        return t.__forward_arg__
    # Older Pythons spelled ForwardRef with a leading underscore.
    if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
        return t.__forward_arg__

    trivial_mappings = {
        str : 'str',
        int : 'int',
        float: 'float',
        bool: 'bool',
        torch.dtype: 'torch.dtype',
        torch.Tensor: 'torch.Tensor',
        torch.device: 'torch.device',
        torch.memory_format: 'torch.memory_format',
        slice: 'slice',
        torch.nn.Module: 'torch.nn.modules.module.Module',
        torch.fx.Graph : 'torch.fx.graph.Graph',
        torch.fx.Node : 'torch.fx.node.Node',
        torch.fx.Proxy : 'torch.fx.proxy.Proxy',
        torch.fx.node.Target : 'torch.fx.node.Target',
        torch.fx.node.Argument : 'torch.fx.node.Argument',
        torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
        torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
        torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
        Ellipsis : '...',
        typing.Any: 'Any',
        type(None): 'NoneType',
        None: 'None',
        typing.Iterator: 'Iterator',
    }

    mapping = trivial_mappings.get(t, None)
    if mapping:
        return mapping

    # Handle types with contained types
    contained = getattr(t, '__args__', None) or []

    # Callables contain a bare List for arguments
    contained = t if isinstance(t, list) else contained

    # Python 3.8 puts type vars into __args__ for unbound types such as Dict
    if all(isinstance(ct, typing.TypeVar) for ct in contained):
        contained = []

    contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
    contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''

    origin = getattr(t, '__origin__', None)
    if origin is None:
        union_type = getattr(types, 'UnionType', None)  # only exists on Python 3.10+
        if union_type is not None and isinstance(t, union_type):
            # PEP 604 unions (``X | Y``) carry ``__args__`` but no
            # ``__origin__``; render them exactly like the equivalent
            # typing.Union / Optional so that switching spelling does not
            # change the serialized signature.
            origin = typing.Union
        elif t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable}:
            # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
            origin = t

    if origin in {tuple, typing.Tuple}:
        return f'Tuple{contained_type_str}'
    if origin in {typing.Union}:
        # Annoying hack to detect Optional: exactly two members, one of them NoneType.
        if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
            not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
            return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
        return f'Union{contained_type_str}'
    if origin in {dict, typing.Dict}:
        return f'Dict{contained_type_str}'
    if origin in {list, typing.List}:
        return f'List{contained_type_str}'
    if origin in {type, typing.Type}:
        return f'Type{contained_type_str}'
    # NOTE: this is an `isinstance` check, so it matches anything callable
    # that fell through the branches above (plain classes and other generic
    # aliases included), not only `typing.Callable[...]`. Tightening it would
    # change the serialized form of existing signatures, and hence the
    # expected file, so it is intentionally left as is.
    if isinstance(t, typing.Callable):
        if len(contained) > 0 and contained[0] is not Ellipsis:
            return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
        else:
            return f'Callable{contained_type_str}'

    raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}. '
                       'Please add support for this type and confirm with the '
                       'FX team that your signature change is valid.')
|
|
|
|
|
|
def test_function_back_compat(self):
    """
    Test backward compatibility for function signatures with
    @compatibility(is_backward_compatible=True). Currently this checks for
    exact signature matches, which may lead to false positives. If this
    becomes too annoying, we can refine this check to actually parse out
    the saved schema strings and check if the change is truly backward-
    incompatible.
    """
    # Classes are checked by test_class_member_back_compat; only non-class
    # objects are serialized as signatures here.
    signature_strs = []

    for obj in _BACK_COMPAT_OBJECTS:
        if not isinstance(obj, type):
            signature_strs.append(self._fn_to_stable_annotation_str(obj))

    # Sort so the expected output does not depend on registration order.
    signature_strs.sort()

    try:
        self.assertExpected('\n'.join(signature_strs) + '\n', 'fx_backcompat_function_signatures')
    except AssertionError as e:
        msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
              f"as backwards-compatible has experienced a signature change. See the " \
              f"above exception context for more information. If this change was " \
              f"unintended, please revert it. If it was intended, check with the FX " \
              f"team to ensure that the proper deprecation protocols have been followed " \
              f"and subsequently --accept the change."
        # Chain explicitly (as test_class_member_back_compat does) so the
        # expected-vs-actual mismatch is reported as the direct cause.
        raise AssertionError(msg) from e
|
|
|
|
def test_class_member_back_compat(self):
    """
    Check that classes marked @compatibility(is_backward_compatible=True)
    keep exactly the same set of public members.

    A member is any name in the class's own ``__dict__`` that does not start
    with an underscore. One ``"<qualified name> [sorted members]"`` line is
    produced per class, and the sorted lines must match the
    ``fx_backcompat_class_members`` expected output exactly.
    """
    def describe(cls):
        public = sorted(name for name in vars(cls) if not name.startswith('_'))
        return f'{torch.typename(cls)} {public}'

    class_method_strs = sorted(
        describe(obj) for obj in _BACK_COMPAT_OBJECTS if isinstance(obj, type)
    )

    try:
        self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
    except AssertionError as e:
        msg = (
            f"{e}\n****** ERROR ******\nAn FX class that has been marked "
            "as backwards-compatible has experienced change in its public members. See the "
            "above exception context for more information. If this change was "
            "unintended, please revert it. If it was intended, check with the FX "
            "team to ensure that the proper deprecation protocols have been followed "
            "and subsequently --accept the change."
        )
        raise AssertionError(msg) from e
|
|
|
|
def test_public_api_surface(self):
    """
    Fail if a public class or function under ``torch.fx`` (walked from both
    ``torch.fx`` and ``torch.fx.passes``) has no backward-compatibility
    designation.

    Sub-modules are followed recursively through their ``__dict__``.
    ``torch.fx.experimental`` and any dotted path with an underscore-prefixed
    component are exempt.
    """
    # Insertion-ordered set of offending objects; the dict values are unused.
    unmarked = {}

    def visit(module, seen):
        mod_name = module.__name__
        if not mod_name.startswith('torch.fx') or mod_name.startswith('torch.fx.experimental'):
            return
        # It's really common for inner functions to point to random modules
        # - make sure we don't recurse into modules we've already checked.
        seen.add(mod_name)
        for attr, value in module.__dict__.items():
            already_seen = hasattr(value, '__name__') and value.__name__ in seen
            if already_seen or value is module or attr.startswith('_'):
                continue
            if isinstance(value, types.ModuleType):
                visit(value, seen)
            elif isinstance(value, (type, types.FunctionType)) and value not in _MARKED_WITH_COMPATIBILITY:
                unmarked[value] = None

    visit(torch.fx, set())
    visit(torch.fx.passes, set())

    def is_public_fx_name(qualname):
        # Lives in torch.fx (but not torch.fx.experimental) and has no
        # private component anywhere in its dotted path.
        return (
            qualname.startswith('torch.fx')
            and not qualname.startswith('torch.fx.experimental')
            and not any(atom.startswith('_') for atom in qualname.split('.'))
        )

    offenders = sorted(
        qualname for qualname in map(torch.typename, unmarked) if is_public_fx_name(qualname)
    )

    if offenders:
        raise AssertionError(
            f"Public FX API(s) {offenders} introduced but not given a "
            "backwards-compatibility classification! Please decorate these "
            "API(s) with `@torch.fx._compatibility.compatibility` to specify "
            "BC guarantees."
        )
|
|
|
|
def test_adding_side_effect_function(self):
    """
    A call whose result is unused must survive dead-code elimination when its
    target is ``side_effect_func``.

    NOTE(review): ``side_effect_func`` is defined outside this chunk;
    presumably it is registered as side-effectful (e.g. via
    ``torch.fx.node.has_side_effect``). Confirm.
    """
    class TestModule(torch.nn.Module):
        def forward(self, x):
            side_effect_func(x)
            return x

    gm = torch.fx.symbolic_trace(TestModule())
    # placeholder `x`, the side_effect_func call, and the output node
    self.assertEqual(len(gm.graph.nodes), 3)

    gm.graph.eliminate_dead_code()
    gm.recompile()

    # Nothing may have been removed, and the call itself must still be there.
    self.assertEqual(len(gm.graph.nodes), 3)
    self.assertTrue(
        any(
            node.op == 'call_function' and node.target == side_effect_func
            for node in gm.graph.nodes
        )
    )
|
|
|
|
def test_preserve_unused_attr_after_unpickle(self):
    """
    Attributes the traced graph never references (a submodule, a buffer and
    a parameter) must still exist after a ``torch.save`` / ``torch.load``
    round trip of the GraphModule.

    ``Add`` is a module defined outside this chunk.
    """
    gm = torch.fx.symbolic_trace(Add())
    gm.add_submodule("foo", Add())
    gm.dummy_buffer = torch.nn.Buffer(torch.empty(1))
    gm.register_parameter("dummy_parameter", torch.nn.Parameter(torch.empty(1)))

    buffer = io.BytesIO()
    torch.save(gm, buffer)
    buffer.seek(0)
    # weights_only=False as this loads a GraphModule
    # GLOBAL torch.fx.graph_module.reduce_graph_module was not an allowed global by default
    reload_gm = torch.load(buffer, weights_only=False)

    for attr_name in ("foo", "dummy_buffer", "dummy_parameter"):
        self.assertTrue(hasattr(reload_gm, attr_name))
|
|
|
|
# This is failing on Python 3.12 : https://github.com/pytorch/pytorch/issues/119454
@unittest.skipIf(
    sys.version_info >= (3, 12), "Failing on python 3.12+"
)
class TestFunctionalTracing(JitTestCase):
    """Check which public ``torch.nn.functional`` entry points FX can trace.

    ``generate_tests`` attaches one ``test_nn_functional_<name>`` method per
    functional returned by ``_get_functional``. Functionals listed in
    ``UNTRACEABLE_FUNCTIONALS`` (or, on Python 3.8-3.11,
    ``UNTRACEABLE_FUNCTIONALS_PY38``) must fail ``symbolic_trace`` with the
    given ``(exception type, message regex)``; all others must trace cleanly.
    """

    def setUp(self):
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        # Undo setUp in reverse order: restore the global flag first so that a
        # failing base-class teardown cannot leak it into later tests.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
        super().tearDown()

    IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
                    "has_torch_function_variadic", "handle_torch_function",
                    "boolean_dispatch")
    # Originals of these helpers are stored here by setUpClass.
    TO_PATCH = {"has_torch_function": None,
                "has_torch_function_unary": None,
                "has_torch_function_variadic": None}

    # (exception type, message regex) pairs for assertRaisesRegex.
    BUILT_IN_FUNC = (AssertionError, "")
    PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
    ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
    CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
    INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
    MUTABLE = (RuntimeError, r"Tried to trace mutable operation")

    UNTRACEABLE_FUNCTIONALS = {
        "adaptive_avg_pool1d": BUILT_IN_FUNC,
        "avg_pool1d": BUILT_IN_FUNC,
        "avg_pool2d": BUILT_IN_FUNC,
        "avg_pool3d": BUILT_IN_FUNC,
        "bilinear": BUILT_IN_FUNC,
        "celu_": BUILT_IN_FUNC,
        "channel_shuffle": BUILT_IN_FUNC,
        "native_channel_shuffle": BUILT_IN_FUNC,
        "conv1d": BUILT_IN_FUNC,
        "conv2d": BUILT_IN_FUNC,
        "conv3d": BUILT_IN_FUNC,
        "conv_tbc": BUILT_IN_FUNC,
        "conv_transpose1d": BUILT_IN_FUNC,
        "conv_transpose2d": BUILT_IN_FUNC,
        "conv_transpose3d": BUILT_IN_FUNC,
        "cosine_similarity": BUILT_IN_FUNC,
        "elu_": BUILT_IN_FUNC,
        "gelu": BUILT_IN_FUNC,
        "hardshrink": BUILT_IN_FUNC,
        "hardtanh_": BUILT_IN_FUNC,
        "leaky_relu_": BUILT_IN_FUNC,
        "linear": BUILT_IN_FUNC,
        "logsigmoid": BUILT_IN_FUNC,
        "one_hot": BUILT_IN_FUNC,
        "pad": ARG_TYPE_MISMATCH,
        "pairwise_distance": BUILT_IN_FUNC,
        "pdist": BUILT_IN_FUNC,
        "pixel_shuffle": BUILT_IN_FUNC,
        "pixel_unshuffle": BUILT_IN_FUNC,
        "prelu": BUILT_IN_FUNC,
        "relu_": BUILT_IN_FUNC,
        "rrelu_": BUILT_IN_FUNC,
        "selu_": BUILT_IN_FUNC,
        "scaled_dot_product_attention": BUILT_IN_FUNC,
        "softplus": BUILT_IN_FUNC,
        "softshrink": BUILT_IN_FUNC,
        "threshold_": BUILT_IN_FUNC,

        "adaptive_avg_pool2d": LEN_ERROR,
        "adaptive_avg_pool3d": LEN_ERROR,
        "adaptive_max_pool2d_with_indices": LEN_ERROR,
        "adaptive_max_pool3d_with_indices": LEN_ERROR,
        "instance_norm": CONTROL_FLOW,

        "adaptive_max_pool1d": PROXY_ITERABLE,
        "adaptive_max_pool2d": PROXY_ITERABLE,
        "adaptive_max_pool3d": PROXY_ITERABLE,
        "fractional_max_pool2d": PROXY_ITERABLE,
        "fractional_max_pool3d": PROXY_ITERABLE,
        "max_pool1d": PROXY_ITERABLE,
        "max_pool2d": PROXY_ITERABLE,
        "max_pool3d": PROXY_ITERABLE,

        "lp_pool2d": PROXY_ITERATED,
        "lp_pool3d": PROXY_ITERATED,
        "max_unpool1d": PROXY_ITERATED,
        "max_unpool2d": PROXY_ITERATED,
        "max_unpool3d": PROXY_ITERATED,
        "fold": PROXY_ITERATED,
        "unfold": PROXY_ITERATED,

        "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "layer_norm": ARG_TYPE_MISMATCH,
        "rms_norm": ARG_TYPE_MISMATCH,
        "lp_pool1d": ARG_TYPE_MISMATCH,

        "affine_grid": CONTROL_FLOW,
        "alpha_dropout": CONTROL_FLOW,
        "batch_norm": CONTROL_FLOW,
        "binary_cross_entropy": CONTROL_FLOW,
        "binary_cross_entropy_with_logits": CONTROL_FLOW,
        "celu": CONTROL_FLOW,
        "cosine_embedding_loss": CONTROL_FLOW,
        "cross_entropy": CONTROL_FLOW,
        "ctc_loss": CONTROL_FLOW,
        "dropout": CONTROL_FLOW,
        "dropout1d": CONTROL_FLOW,
        "dropout2d": CONTROL_FLOW,
        "dropout3d": CONTROL_FLOW,
        "elu": CONTROL_FLOW,
        "embedding": CONTROL_FLOW,
        "embedding_bag": CONTROL_FLOW,
        "feature_alpha_dropout": CONTROL_FLOW,
        "gaussian_nll_loss": CONTROL_FLOW,
        "glu": CONTROL_FLOW,
        "grid_sample": CONTROL_FLOW,
        "group_norm": CONTROL_FLOW,
        "gumbel_softmax": CONTROL_FLOW,
        "hardsigmoid": CONTROL_FLOW,
        "hardswish": CONTROL_FLOW,
        "hardtanh": CONTROL_FLOW,
        "hinge_embedding_loss": CONTROL_FLOW,
        "huber_loss": CONTROL_FLOW,
        "interpolate": CONTROL_FLOW,
        "kl_div": CONTROL_FLOW,
        "l1_loss": CONTROL_FLOW,
        "leaky_relu": CONTROL_FLOW,
        "local_response_norm": CONTROL_FLOW,
        "margin_ranking_loss": CONTROL_FLOW,
        "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
        "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
        "mse_loss": CONTROL_FLOW,
        "multi_head_attention_forward": CONTROL_FLOW,
        "multi_margin_loss": CONTROL_FLOW,
        "multilabel_margin_loss": CONTROL_FLOW,
        "multilabel_soft_margin_loss": CONTROL_FLOW,
        "nll_loss": CONTROL_FLOW,
        "poisson_nll_loss": CONTROL_FLOW,
        "relu": CONTROL_FLOW,
        "relu6": CONTROL_FLOW,
        "rrelu": CONTROL_FLOW,
        "selu": CONTROL_FLOW,
        "silu": CONTROL_FLOW,
        "mish": CONTROL_FLOW,
        "smooth_l1_loss": CONTROL_FLOW,
        "soft_margin_loss": CONTROL_FLOW,
        "threshold": CONTROL_FLOW,
        "triplet_margin_loss": CONTROL_FLOW,
        "triplet_margin_with_distance_loss": CONTROL_FLOW,
        "upsample": CONTROL_FLOW,

        "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
        "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
    }

    # List of nn.functionals with Tensor inputs but not with type annotation
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )

    # Interpreter-dependent behavior: on Python 3.8 through 3.11 some
    # functionals re-raise the internal exception (e.g. PROXY_ITERATED)
    # instead of `argument of type 'Proxy' is not iterable`. On those
    # versions this map overrides UNTRACEABLE_FUNCTIONALS (see
    # generate_test_func).
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": CONTROL_FLOW
    }

    @classmethod
    def _get_functional(cls):
        """Return ``(name, fn)`` pairs for the functionals to generate tests for.

        Skips names that are not all-lowercase, private names, the helpers in
        ``IGNORE_FUNCS``, and non-callables. Unless a name is listed in
        ``FUNCTIONALS_WITHOUT_ANNOTATION``, it is also skipped when its
        signature has no parameter annotated with a ``torch.Tensor``
        subclass. Functionals whose signature cannot be inspected are kept.
        """
        functional_list = []
        for f in dir(torch.nn.functional):
            if not f.islower():
                continue
            # Ignore internal functions
            if f.startswith('_'):
                continue
            # Ignore supporting functions
            if f in cls.IGNORE_FUNCS:
                continue
            fn = getattr(torch.nn.functional, f)
            # Ignore non-callable object like modules
            if not callable(fn):
                continue
            if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
                try:
                    sig = inspect.signature(fn)
                except (ValueError, TypeError):
                    # inspect.signature raises ValueError when no signature
                    # can be provided and TypeError when the object type is
                    # not supported; keep such functionals instead of failing
                    # test generation at import time.
                    pass
                else:
                    has_tensor_arg = any(
                        isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor)
                        for param in sig.parameters.values()
                    )
                    if not has_tensor_arg:
                        continue
            functional_list.append((f, fn))
        return functional_list

    @classmethod
    def generate_test_func(cls, func_name, fn):
        """Build a test method that symbolically traces ``fn``.

        The test expects the failure registered for ``func_name`` (preferring
        the Python 3.8-3.11 override) and otherwise expects tracing to succeed.
        """

        def functional_test(self):
            expected = None
            if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and (3, 8) <= sys.version_info < (3, 12):
                expected = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
            elif func_name in self.UNTRACEABLE_FUNCTIONALS:
                expected = self.UNTRACEABLE_FUNCTIONALS[func_name]

            if expected is None:
                symbolic_trace(fn)
            else:
                exc, err = expected
                with self.assertRaisesRegex(exc, err):
                    symbolic_trace(fn)
        return functional_test

    @classmethod
    def generate_tests(cls):
        """Attach a ``test_nn_functional_<name>`` method for every selected functional."""
        functional_list = cls._get_functional()
        for func_name, fn in functional_list:
            test_name = "test_nn_functional_" + func_name
            functional_test = cls.generate_test_func(func_name, fn)
            setattr(cls, test_name, functional_test)

    @classmethod
    def setUpClass(cls):
        super().setUpClass()

        def no(*args, **kwargs):
            return False

        # Replace the has_torch_function* helpers in torch.nn.functional with
        # a stub that always returns False (presumably so Proxy arguments do
        # not route through __torch_function__ dispatch while tracing; confirm).
        # The originals are saved in TO_PATCH and restored in tearDownClass.
        for name in cls.TO_PATCH.keys():
            cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
            setattr(torch.nn.functional, name, no)

    @classmethod
    def tearDownClass(cls):
        for name, original in cls.TO_PATCH.items():
            setattr(torch.nn.functional, name, original)
        super().tearDownClass()
|
|
|
|
# Attach the generated test_nn_functional_<name> methods to
# TestFunctionalTracing at import time so the test runner discovers them.
TestFunctionalTracing.generate_tests()
|
|
|
|
|
|
# Instantiate device-specific variants of TestOperatorSignatures (defined
# outside this chunk) into this module's globals so the runner discovers them.
instantiate_device_type_tests(TestOperatorSignatures, globals())
|
|
|
|
@skipIfTorchDynamo("too slow")
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    """Symbolically trace, and then script, every torchvision model.

    ``generate_tests`` registers one ``test_torchvision_models_*`` method per
    model name returned by ``torchvision_models.list_models`` for the
    classification, detection, segmentation and video packages.

    Each test compares eager and traced outputs, and (unless the model is
    listed in ``UNSCRIPTABLE_MODELS``) eager and scripted outputs. Models in
    ``UNTRACEABLE_MODELS`` / ``UNSCRIPTABLE_MODELS`` must instead fail with
    the listed ``(exception type, message regex)``.
    """

    def setUp(self):
        # NOTE(review): unlike TestFunctionalTracing, setUp/tearDown here do
        # not call super(); confirm that skipping the base-class
        # setUp/tearDown is intentional (e.g. for speed).
        # Checking for mutable operations while tracing is feature flagged
        # Enable it in testing but not by default
        tracer_base = torch.fx.proxy.TracerBase
        self.orig_tracer_mutable_flag = tracer_base.check_mutable_operations
        tracer_base.check_mutable_operations = True

    def tearDown(self):
        tracer_base = torch.fx.proxy.TracerBase
        tracer_base.check_mutable_operations = self.orig_tracer_mutable_flag

    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

    UNTRACEABLE_MODELS = dict.fromkeys(
        (
            "fasterrcnn_resnet50_fpn",
            "fasterrcnn_resnet50_fpn_v2",
            "fasterrcnn_mobilenet_v3_large_320_fpn",
            "fasterrcnn_mobilenet_v3_large_fpn",
            "maskrcnn_resnet50_fpn",
            "maskrcnn_resnet50_fpn_v2",
            "keypointrcnn_resnet50_fpn",
            "retinanet_resnet50_fpn",
            "retinanet_resnet50_fpn_v2",
            "ssd300_vgg16",
            "fcos_resnet50_fpn",
            "ssdlite320_mobilenet_v3_large",
        ),
        PROXY_ITERATED,
    )
    UNSCRIPTABLE_MODELS = dict.fromkeys(("googlenet", "inception_v3"), INCONSISTENT_TYPE)

    # Per-model projection applied to outputs before comparing them:
    # segmentation models are compared on their "out" entry, these detection
    # models on element 1 of their output. Other models use their output as is.
    output_transform = {
        **{
            model_name: (lambda x: x["out"])
            for model_name in (
                "fcn_resnet50",
                "fcn_resnet101",
                "deeplabv3_resnet50",
                "deeplabv3_resnet101",
                "deeplabv3_mobilenet_v3_large",
                "lraspp_mobilenet_v3_large",
            )
        },
        **{
            model_name: (lambda x: x[1])
            for model_name in (
                "fasterrcnn_resnet50_fpn",
                "fasterrcnn_mobilenet_v3_large_fpn",
                "fasterrcnn_mobilenet_v3_large_320_fpn",
                "maskrcnn_resnet50_fpn",
                "keypointrcnn_resnet50_fpn",
                "retinanet_resnet50_fpn",
            )
        },
    }

    @classmethod
    def generate_test_fn(cls, name, x, kwargs):
        """Build a test method for torchvision model ``name`` fed with input ``x``."""
        def run_test(self):
            model = torchvision_models.get_model(name, **kwargs).eval()

            if name in self.UNTRACEABLE_MODELS:
                exc_type, pattern = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, pattern):
                    symbolic_trace(model)
                return

            transform = self.output_transform.get(name, lambda out: out)
            graph : torch.fx.GraphModule = symbolic_trace(model)
            expected = transform(model(x))
            self.assertEqual(expected, transform(graph(x)))

            if name in self.UNSCRIPTABLE_MODELS:
                exc_type, pattern = self.UNSCRIPTABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, pattern):
                    torch.jit.script(graph)
            else:
                scripted = torch.jit.script(graph)
                self.assertEqual(expected, transform(scripted(x)))

        return run_test

    @classmethod
    def generate_classification_tests(cls):
        for model_name in torchvision_models.list_models(module=torchvision_models):
            # inception_v3 gets a 299x299 input; every other classifier 224x224.
            side = 299 if model_name in ['inception_v3'] else 224
            x = torch.rand(1, 3, side, side)
            run_test = cls.generate_test_fn(model_name, x, {"num_classes": 50})
            setattr(cls, f'test_torchvision_models_{model_name}', run_test)

    @classmethod
    def generate_segmentation_tests(cls):
        for model_name in torchvision_models.list_models(module=torchvision_models.segmentation):
            x = torch.rand(1, 3, 32, 32)
            run_test = cls.generate_test_fn(
                model_name, x, {"num_classes": 10, "pretrained_backbone": False}
            )
            setattr(cls, f'test_torchvision_models_segmentation_{model_name}', run_test)

    @classmethod
    def generate_detection_tests(cls):
        for model_name in torchvision_models.list_models(module=torchvision_models.detection):
            # Detection models take a list of CHW images.
            x = [torch.rand(3, 300, 300)]
            run_test = cls.generate_test_fn(
                model_name, x, {"num_classes": 10, "pretrained_backbone": False}
            )
            setattr(cls, f'test_torchvision_models_detection_{model_name}', run_test)

    @classmethod
    def generate_video_tests(cls):
        for model_name in torchvision_models.list_models(module=torchvision_models.video):
            if model_name in {"mvit_v1_b", "mvit_v2_s", "s3d"}:
                x = torch.rand(1, 3, 16, 224, 224)
            else:
                x = torch.rand(1, 3, 4, 112, 112)
            run_test = cls.generate_test_fn(model_name, x, {"num_classes": 50})
            setattr(cls, f'test_torchvision_models_video_{model_name}', run_test)

    @classmethod
    def generate_tests(cls):
        """Register the generated tests for every torchvision model family."""
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()
|
|
|
|
# torchvision is optional: only generate the model-tracing tests when it is
# available (HAS_TORCHVISION is defined outside this chunk).
if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()
|
|
|
|
# Allow running this test file directly as a script.
if __name__ == '__main__':
    run_tests()
|