mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65941 OpInfos for: empty_like, zeros_like, ones_like, full_like, randn_like Test Plan: - run tests Reviewed By: dagitses Differential Revision: D31452625 Pulled By: zou3519 fbshipit-source-id: 5e6c45918694853f9252488d62bb7f4ccfa1f1e4
		
			
				
	
	
		
			3891 lines
		
	
	
		
			139 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			3891 lines
		
	
	
		
			139 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import builtins
 | |
| import contextlib
 | |
| import copy
 | |
| import functools
 | |
| import inspect
 | |
| import math
 | |
| import numbers
 | |
| import operator
 | |
| import os
 | |
| import pickle
 | |
| import sys
 | |
| import torch
 | |
| import traceback
 | |
| import typing
 | |
| import types
 | |
| import warnings
 | |
| import unittest
 | |
| from math import sqrt
 | |
| from torch.multiprocessing import Process
 | |
| from torch.testing import FileCheck
 | |
| from torch.testing._internal.common_methods_invocations import op_db
 | |
| from torch.testing._internal.common_device_type import ops, onlyCPU, instantiate_device_type_tests
 | |
| import torch.utils._pytree as pytree
 | |
| import torch.fx._pytree as fx_pytree
 | |
| from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Interpreter, Tracer, Transformer, Graph, wrap, PH
 | |
| import torch._C._fx
 | |
| from torch.fx.node import Target, Argument
 | |
| from torch.fx.passes import shape_prop
 | |
| from torch.fx.immutable_collections import immutable_dict, immutable_list
 | |
| from torch.fx.experimental.rewriter import RewritingTracer
 | |
| from torch.fx.operator_schemas import get_signature_for_torch_op
 | |
| from copy import deepcopy
 | |
| from collections import namedtuple
 | |
| 
 | |
| from torch.fx.proxy import TraceError
 | |
| from torch.fx._compatibility import _BACK_COMPAT_OBJECTS, _MARKED_WITH_COMATIBLITY
 | |
| 
 | |
| from fx.test_subgraph_rewriter import TestSubgraphRewriter  # noqa: F401
 | |
| from fx.test_dce_pass import TestDCE  # noqa: F401
 | |
| from fx.test_fx_const_fold import TestConstFold  # noqa: F401
 | |
| from fx.test_fx_param_shape_control_flow import TestConstParamShapeInControlFlow  # noqa: F401
 | |
| 
 | |
| if sys.version_info >= (3, 7):
 | |
|     from fx.test_gradual_type import AnnotationsTest  # noqa: F401
 | |
| if sys.version_info >= (3, 7):
 | |
|     from fx.test_gradual_type import TypeCheckerTest  # noqa: F401
 | |
| from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
 | |
| from torch.testing._internal.common_utils import (
 | |
|     IS_FBCODE,
 | |
|     IS_MACOS,
 | |
|     IS_WINDOWS,
 | |
|     TEST_WITH_ROCM,
 | |
|     find_library_location,
 | |
|     run_tests,
 | |
| )
 | |
| from torch.testing._internal.jit_utils import JitTestCase
 | |
| 
 | |
| from fx.named_tup import MyNamedTup
 | |
| 
 | |
| try:
 | |
|     from torchvision import models as torchvision_models
 | |
|     HAS_TORCHVISION = True
 | |
| except ImportError:
 | |
|     HAS_TORCHVISION = False
 | |
| skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 | |
| 
 | |
class SimpleTest(torch.nn.Module):
    """Minimal fixture module computing ``relu(x + 3.0)``."""

    def forward(self, x):
        shifted = x + 3.0
        return torch.relu(shifted)
 | |
| 
 | |
def a_non_torch_leaf(a, b):
    """Return ``a + b``.

    A plain Python (non-torch) callable that tests use as the target of
    hand-built ``call_function`` nodes.
    """
    total = a + b
    return total
 | |
| 
 | |
# Used for test_autowrap_function. Autowrapped functions need to be global
def fx_int(x: float) -> int:
    """Convert ``x`` to ``int`` (truncating toward zero, as ``int()`` does)."""
    truncated = int(x)
    return truncated
 | |
| 
 | |
def fx_int_x2(x: float) -> int:
    """Return twice the integer conversion of ``x``."""
    return 2 * int(x)
 | |
| 
 | |
# Used in test_pytree. Defined at module level because pickling a GraphModule
# that uses Point fails if Point is local to the test function.
Point = namedtuple('Point', 'x y')
 | |
| 
 | |
# Test wrap() passing both a function name as well as a function
# directly
def a_lifted_leaf(a, b):
    """Return ``a[0] + a[1] + b``; registered with ``wrap`` by name."""
    first, second = a[0], a[1]
    return first + second + b
 | |
| 
 | |
# Register a_lifted_leaf by name; test_wrap checks that calls to it appear in
# the traced code and that the module global is restored after tracing.
wrap('a_lifted_leaf')
# Test wrapping twice doesn't break anything
wrap('a_lifted_leaf')
 | |
| 
 | |
def a_lifted_leaf2(a, b):
    """Return ``a[0] + a[1] + b``; registered with ``wrap`` via the function object."""
    pair_sum = a[0] + a[1]
    return pair_sum + b
 | |
| 
 | |
# Register by passing the function object itself rather than its name.
wrap(a_lifted_leaf2)

# The builtins `len` and `getattr` are registered by name as well.
wrap('len')

wrap('getattr')
 | |
| 
 | |
@wrap
def wrapped_via_decorator(a):
    """Return ``a + 1``; registered by using ``wrap`` as a decorator."""
    incremented = a + 1
    return incremented
 | |
| 
 | |
wrap('wrapped_with_submodule')

def wrapped_with_submodule(x: torch.Tensor, batchnorm1d: torch.nn.BatchNorm1d):
    """Apply the ``batchnorm1d`` module passed in to ``x`` and return the result.

    Registered with ``wrap``; test_wrap_with_submodule checks that the call
    is preserved in traced code.
    """
    result = batchnorm1d(x)
    return result
 | |
| 
 | |
| 
 | |
# Capture the original function objects at import time. Tests compare against
# these with assertIs to verify that tracing restored the wrapped globals.
# (The "lifed" spelling is part of the names that tests reference.)
real_wrapped_via_decorator = wrapped_via_decorator
real_a_lifed_leaf = a_lifted_leaf
real_a_lifed_leaf2 = a_lifted_leaf2
# NOTE(review): the consumer of _sqrt is outside this chunk; presumably it
# keeps a private reference to math.sqrt for a wrap/patching test — confirm.
_sqrt = sqrt
 | |
| 
 | |
wrap('wrapper_fn')

# NOTE(review): `torch.foo` does not appear to be a real torch API, so calling
# this function eagerly would raise AttributeError. Presumably intentional for
# a test outside this chunk — confirm before changing.
def wrapper_fn(x):
    return torch.foo(x)
 | |
| 
 | |
class Pair(NamedTuple):
    """Typed NamedTuple fixture holding two tensors."""

    x: torch.Tensor
    y: torch.Tensor
 | |
| 
 | |
# for testing pytrees
class Foo(object):  # noqa: B209
    """Plain object holding two attributes, ``a`` and ``b``."""

    def __init__(self, a, b):
        self.a, self.b = a, b
 | |
| 
 | |
| class TestFX(JitTestCase):
 | |
|     def setUp(self):
 | |
|         # Checking for mutable operations whil tracing is feature flagged
 | |
|         # Enable it in testing but not by default
 | |
|         self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = True
 | |
| 
 | |
|         if not (TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS):
 | |
|             lib_file_path = find_library_location('libtorchbind_test.so')
 | |
|             torch.ops.load_library(str(lib_file_path))
 | |
| 
 | |
|     def tearDown(self):
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
 | |
| 
 | |
|     def checkGraphModule(self, m: torch.nn.Module, args, kwargs=None):
 | |
|         """Check that an nn.Module's results match the GraphModule version
 | |
|         for a given set of args/kwargs.
 | |
|         """
 | |
|         kwargs = kwargs if kwargs else {}
 | |
|         ref_outs = m(*args, **kwargs)
 | |
|         gm = symbolic_trace(m)
 | |
|         gm.graph.lint()
 | |
|         test_outs = gm(*args, **kwargs)
 | |
|         self.assertEqual(ref_outs, test_outs)
 | |
| 
 | |
|     def test_graph_module(self):
 | |
|         class MySub(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.w = torch.nn.Parameter(torch.rand(4, 3))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.w + x
 | |
| 
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.lin = torch.nn.Linear(4, 3)
 | |
|                 self.sub_mod = MySub()
 | |
|                 self.w = torch.nn.Parameter(torch.rand(3))
 | |
| 
 | |
|             def forward(self, A, B, c):
 | |
|                 t = torch.sigmoid(A) + self.lin(c)
 | |
|                 return self.sub_mod(t.data + self.w + t + 1 - A + B // A + -A + A.add(B, alpha=3))
 | |
| 
 | |
|         m = MyModule()
 | |
|         gm = symbolic_trace(m)
 | |
| 
 | |
|         ms = torch.jit.script(gm)
 | |
| 
 | |
|         class M2(torch.nn.Module):
 | |
|             def forward(self, A):
 | |
|                 m, idx = torch.max(A, 0)
 | |
|                 return m + 1, idx + 1
 | |
| 
 | |
|         m2 = M2()
 | |
|         gm2 = symbolic_trace(m2)
 | |
| 
 | |
|         class T(torch.nn.Module):
 | |
| 
 | |
|             def forward(self, A, b=4, *args, c=5, **kwargs):
 | |
|                 x = A + 1 + args[0] + kwargs['3']
 | |
|                 return x
 | |
| 
 | |
|         t = T()
 | |
|         symbolic_trace(t)
 | |
| 
 | |
|         # test for issue described at https://github.com/pytorch/pytorch/issues/63883
 | |
|         class M3(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         m3 = M3()
 | |
|         gm3 = symbolic_trace(m3)
 | |
|         new_instance = gm3.__new__(type(gm3))
 | |
|         new_instance.__init__(gm3, gm3.graph)
 | |
| 
 | |
|         x = torch.randn(5, 3)
 | |
|         torch.testing.assert_allclose(new_instance(x), torch.relu(x))
 | |
| 
 | |
|     def test_custom_import(self):
 | |
|         graph = torch.fx.Graph()
 | |
|         a = graph.placeholder('x')
 | |
|         b = graph.placeholder('y')
 | |
|         c = graph.call_function(a_non_torch_leaf, (a, b))
 | |
|         d = graph.call_function(torch.sin, (c,))
 | |
|         graph.output(d)
 | |
|         gm = GraphModule(torch.nn.Module(), graph)
 | |
|         x, y = torch.rand(1), torch.rand(1)
 | |
|         self.assertEqual(torch.sin(x + y), gm(x, y))
 | |
| 
 | |
|     def test_args_kwargs(self):
 | |
|         class T(torch.nn.Module):
 | |
|             def forward(self, *args, **kwargs):
 | |
|                 x = args[0] + kwargs['foo']
 | |
|                 return x
 | |
| 
 | |
|         t = T()
 | |
|         self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
 | |
| 
 | |
|     def test_args_kwargs_no_self(self):
 | |
|         class T(torch.nn.Module):
 | |
|             def forward(*args, **kwargs):  # noqa: B902
 | |
|                 self = args[0]
 | |
|                 return torch.relu(args[1])
 | |
| 
 | |
|         t = T()
 | |
|         with self.assertRaisesRegex(RuntimeError, r'cannot be part of \*args expansion'):
 | |
|             self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
 | |
| 
 | |
|     def test_fx_shifts(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return x << 3, x >> 3
 | |
| 
 | |
|         input = torch.LongTensor(10).random_(0, 1024)
 | |
| 
 | |
|         m = MyModule()
 | |
|         self.checkGraphModule(m, (input,))
 | |
| 
 | |
|     def test_fx_and_or(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return x & x, x | x
 | |
| 
 | |
|         input = torch.LongTensor(10).random_(0, 1024)
 | |
| 
 | |
|         m = MyModule()
 | |
|         self.checkGraphModule(m, (input,))
 | |
| 
 | |
|     def test_dict(self):
 | |
|         class MyDictMod(torch.nn.Module):
 | |
|             def forward(self, d):
 | |
|                 return d['3'].relu(), {'4' : d['3'].neg()}
 | |
| 
 | |
|         input_dict = {'3': torch.rand(3, 4)}
 | |
|         m = MyDictMod()
 | |
| 
 | |
|         self.checkGraphModule(m, (input_dict,))
 | |
| 
 | |
|     def test_matmul_tracing(self):
 | |
|         const = torch.randn(3)
 | |
| 
 | |
|         def matmul_f(x):
 | |
|             return x @ const
 | |
| 
 | |
|         mod = symbolic_trace(matmul_f)
 | |
|         inp = torch.randn(3)
 | |
|         self.assertEqual(mod(inp), matmul_f(inp))
 | |
| 
 | |
|         def rmatmul_f(x):
 | |
|             return const @ x
 | |
| 
 | |
|         mod = symbolic_trace(rmatmul_f)
 | |
|         inp = torch.randn(3)
 | |
|         self.assertEqual(mod(inp), rmatmul_f(inp))
 | |
| 
 | |
| 
 | |
|     def test_disallow_override(self):
 | |
|         # Custom delegate to disallow in-place tensor operations
 | |
|         class NoMutableCallTracer(Tracer):
 | |
|             def create_node(self, kind : str, target : Union[str, Callable],
 | |
|                             args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 | |
|                             type_expr : Optional[Any] = None) -> Node:
 | |
|                 name = target if isinstance(target, str) else torch.typename(target)
 | |
|                 if name[-1] == '_':
 | |
|                     raise RuntimeError('In-place operations are not supported')
 | |
|                 return super().create_node(kind, target, args, kwargs, name)
 | |
| 
 | |
|         # Test method
 | |
|         class MyInplaceMod(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 x.add_(3.0)
 | |
|                 return x
 | |
| 
 | |
|         m = MyInplaceMod()
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
 | |
|             NoMutableCallTracer().trace(m)
 | |
| 
 | |
|         # Test free function
 | |
|         class MyInplaceMod2(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 torch.log_(x)
 | |
|                 return x
 | |
|         m2 = MyInplaceMod2()
 | |
|         with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
 | |
|             NoMutableCallTracer().trace(m2)
 | |
| 
 | |
|         # Test symbolic node as an arg
 | |
|         class MyInplaceMod3(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 y = torch.ones(3, 4)
 | |
|                 y.add_(x)
 | |
|                 return x
 | |
|         m3 = MyInplaceMod3()
 | |
|         with self.assertRaisesRegex(RuntimeError, 'In-place operations'):
 | |
|             NoMutableCallTracer().trace(m3)
 | |
| 
 | |
|     def test_leaf_module(self):
 | |
|         # Custom delegate to make it so that there are no leaf modules, everything
 | |
|         # should get traced through
 | |
|         class NoLeafModulesTracer(Tracer):
 | |
|             def is_leaf_module(self, m, qualname):
 | |
|                 return False
 | |
| 
 | |
|         class MyReluMod(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.relu = torch.nn.ReLU()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.relu(x)
 | |
| 
 | |
|         mrm = MyReluMod()
 | |
|         sym = NoLeafModulesTracer().trace(mrm)
 | |
|         for node in sym.nodes:
 | |
|             self.assertNotEqual(node.op, 'call_module')
 | |
|         sym.lint()
 | |
| 
 | |
|     def test_wrap(self):
 | |
|         self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return a_lifted_leaf((4, y), 3) + a_lifted_leaf((3, 4), 5) + a_lifted_leaf((y, y), y)
 | |
| 
 | |
|         m = symbolic_trace(to_trace)
 | |
|         self.assertIn('a_lifted_leaf', m.code)
 | |
|         self.assertEqual(27, m(2))
 | |
|         self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
 | |
| 
 | |
|     def test_wrap_fn_directly(self):
 | |
|         self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return a_lifted_leaf2((4, y), 3) + a_lifted_leaf2((3, 4), 5) + a_lifted_leaf2((y, y), y)
 | |
| 
 | |
|         m = symbolic_trace(to_trace)
 | |
|         self.assertIn('a_lifted_leaf2', m.code)
 | |
|         self.assertEqual(27, m(2))
 | |
|         self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
 | |
| 
 | |
|     def test_wrapped_via_decorator(self):
 | |
|         self.assertEqual(wrapped_via_decorator(0), 1)
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return wrapped_via_decorator(y)
 | |
| 
 | |
|         m = symbolic_trace(to_trace)
 | |
|         self.assertIn('wrapped_via_decorator', m.code)
 | |
|         self.assertEqual(m(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|     def test_wrapped_via_decorator_and_transformed(self):
 | |
|         self.assertEqual(wrapped_via_decorator(0), 1)
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return wrapped_via_decorator(y)
 | |
| 
 | |
|         m = symbolic_trace(to_trace)
 | |
|         self.assertIn('wrapped_via_decorator', m.code)
 | |
|         self.assertEqual(m(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|         transformed = torch.fx.Transformer(m).transform()
 | |
|         self.assertIn('wrapped_via_decorator', transformed.code)
 | |
|         self.assertEqual(transformed(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|     def test_wrap_with_submodule(self):
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(M, self).__init__()
 | |
|                 self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
 | |
| 
 | |
|             def forward(self, x: torch.Tensor):
 | |
|                 return wrapped_with_submodule(x, self.batchnorm1d)
 | |
| 
 | |
|         m = symbolic_trace(M())
 | |
| 
 | |
|         self.assertIn("wrapped_with_submodule", m.code)
 | |
| 
 | |
|         input = torch.rand(3, 2)
 | |
|         ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
 | |
|         self.assertEqual(ref_batchnorm1d(input), m(input))
 | |
| 
 | |
|     def test_wrapped_retrace(self):
 | |
|         def to_trace(y):
 | |
|             return wrapped_via_decorator(y)
 | |
| 
 | |
|         m = symbolic_trace(to_trace)
 | |
|         self.assertIn('wrapped_via_decorator', m.code)
 | |
|         self.assertEqual(m(0), 1)
 | |
| 
 | |
|         retraced = symbolic_trace(m)
 | |
|         self.assertIn('wrapped_via_decorator', retraced.code)
 | |
|         self.assertEqual(retraced(0), 1)
 | |
| 
 | |
|     def test_graph_edit_with_proxy(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
|         m = M()
 | |
|         g = symbolic_trace(m).graph
 | |
|         new_g = torch.fx.Graph()
 | |
|         val_map : Dict[Node, Node] = {}
 | |
|         output_val = new_g.graph_copy(g, val_map)
 | |
|         t = Proxy(output_val)
 | |
|         # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
 | |
|         new_g.output((t + t).node)
 | |
|         gm = GraphModule(m, new_g)
 | |
|         gm.graph.lint()
 | |
|         self.assertEqual(gm(3, 4), 14)
 | |
| 
 | |
|     def test_graph_unique_names(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
|         m = M()
 | |
|         g = symbolic_trace(m).graph
 | |
|         new_g = torch.fx.Graph()
 | |
|         val_map : Dict[Node, Node] = {}
 | |
|         output_val = new_g.graph_copy(g, val_map)
 | |
|         t = Proxy(output_val)
 | |
|         # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
 | |
|         new_g.output((t + t).node)
 | |
|         gm = GraphModule(m, new_g)
 | |
|         seen_names : Set[str] = set()
 | |
|         for node in gm.graph.nodes:
 | |
|             assert node.name not in seen_names
 | |
|             seen_names.add(node.name)
 | |
| 
 | |
|     def test_stack_traces(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
| 
 | |
|         tracer = torch.fx.Tracer()
 | |
|         tracer.record_stack_traces = True
 | |
| 
 | |
|         graph = tracer.trace(M())
 | |
|         for node in graph.nodes:
 | |
|             if node.op == 'output':
 | |
|                 continue
 | |
|             self.assertTrue(node.stack_trace is not None)
 | |
|             assert 'test_fx.py' in node.stack_trace
 | |
| 
 | |
|     def test_graph_unique_names_manual(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         a : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,), name='foo_1_1')
 | |
|         c : torch.fx.Node = graph.create_node('get_attr', 'y_attr', name='foo_1')
 | |
|         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
 | |
|         graph.output(d)
 | |
|         graph2 = torch.fx.Graph()
 | |
|         val_map : Dict[Node, Node] = {}
 | |
|         graph2.graph_copy(graph, val_map)
 | |
|         seen_names : Set[str] = set()
 | |
|         for node in graph2.nodes:
 | |
|             assert node.name not in seen_names
 | |
|             seen_names.add(node.name)
 | |
| 
 | |
|     def test_unpack(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 c, d = a
 | |
|                 return c + d + b
 | |
| 
 | |
|         a = (torch.rand(1), torch.rand(1))
 | |
|         b = torch.rand(1)
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, (a, b))
 | |
| 
 | |
|     def test_native_callable(self):
 | |
|         if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS:
 | |
|             raise unittest.SkipTest("non-portable load_library call used in test")
 | |
|         # This test exercises the case where we use FX to translate from Python
 | |
|         # code to some native callable object
 | |
|         #
 | |
|         # For the purposes of testing, we use ElementwiseInterpreter defined
 | |
|         # in test_custom_class.cpp.
 | |
|         #
 | |
|         # We test that we can
 | |
|         # 1) Construct a native callable from FX IR
 | |
|         # 2) Construct a drop-in replacement module that delegates to the
 | |
|         #    native callable rather than the original code
 | |
|         # 3) Run both the original code and native callable wrapper with
 | |
|         #    equivalent results
 | |
|         # 4) TorchScript compile the native callable wrapper and confirm
 | |
|         #    equivalent results with the reference
 | |
|         # 5) TorchScript serialize and deserialize the native callable
 | |
|         #    and confirm equivalent results with the reference
 | |
| 
 | |
|         # We use this simple Module as a reference computation
 | |
|         class MySimpleMod(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return 3.0 * x + x
 | |
| 
 | |
|         msm = MySimpleMod()
 | |
| 
 | |
|         # This is what a lowering pass might look like: a function that takes
 | |
|         # a valid nn.Module, symbolically traces it, lowers the Module to some
 | |
|         # representation, and wraps that representation up into another
 | |
|         # nn.Module instance that handles dispatch to the compiled/lowered code.
 | |
|         def lower_to_elementwise_interpreter(orig_mod : torch.nn.Module) -> torch.nn.Module:
 | |
|             # ===== Stage 1: Symbolic trace the module =====
 | |
|             mod = symbolic_trace(orig_mod)
 | |
| 
 | |
|             # ===== Stage 2: Lower GraphModule representation to the C++
 | |
|             #       interpreter's instruction format ======
 | |
|             instructions = []
 | |
|             constant_idx = 0
 | |
|             constants = {}
 | |
|             fn_input_names = []
 | |
| 
 | |
|             target_to_name = {
 | |
|                 operator.add : "add",
 | |
|                 operator.mul : "mul"
 | |
|             }
 | |
| 
 | |
|             output_node : Optional[Node] = None
 | |
|             # For each instruction, create a triple
 | |
|             # (instruction_name : str, inputs : List[str], output : str)
 | |
|             # to feed into the C++ interpreter
 | |
|             for n in mod.graph.nodes:
 | |
|                 target, args, out_name = n.target, n.args, n.name
 | |
|                 assert len(n.kwargs) == 0, "kwargs currently not supported"
 | |
| 
 | |
|                 if n.op == 'placeholder':
 | |
|                     # Placeholders specify function argument names. Save these
 | |
|                     # for later when we generate the wrapper GraphModule
 | |
|                     fn_input_names.append(target)
 | |
|                 elif n.op == 'call_function':
 | |
|                     assert target in target_to_name, "Unsupported call target " + target
 | |
|                     arg_names = []
 | |
|                     for arg in args:
 | |
|                         if not isinstance(arg, Node):
 | |
|                             # Pull out constants. These constants will later be
 | |
|                             # fed to the interpreter C++ object via add_constant()
 | |
|                             arg_name = f'constant_{constant_idx}'
 | |
|                             constants[arg_name] = torch.tensor(
 | |
|                                 [arg] if isinstance(arg, numbers.Number) else arg)
 | |
|                             arg_names.append(arg_name)
 | |
|                             constant_idx += 1
 | |
|                         else:
 | |
|                             arg_names.append(arg.name)
 | |
|                     instructions.append((target_to_name[target], arg_names, out_name))
 | |
|                 elif n.op == 'output':
 | |
|                     if output_node is not None:
 | |
|                         raise RuntimeError('Multiple output nodes!')
 | |
|                     output_node = n
 | |
|                 else:
 | |
|                     raise RuntimeError('Unsupported opcode ' + n.op)
 | |
| 
 | |
|             interpreter = torch.classes._TorchScriptTesting._ElementwiseInterpreter()
 | |
|             # Load constants
 | |
|             for k, v in constants.items():
 | |
|                 interpreter.add_constant(k, v)
 | |
|             # Specify names for positional input arguments
 | |
|             interpreter.set_input_names(fn_input_names)
 | |
|             # Load instructions
 | |
|             interpreter.set_instructions(instructions)
 | |
|             # Specify name for single output
 | |
|             assert isinstance(output_node.args[0], torch.fx.Node)
 | |
|             interpreter.set_output_name(output_node.args[0].name)
 | |
| 
 | |
|             # ===== Stage 3: Create a wrapper GraphModule around the interpreter =====
 | |
|             class WrapperModule(torch.nn.Module):
 | |
|                 def __init__(self, interpreter):
 | |
|                     super().__init__()
 | |
|                     self.interpreter = interpreter
 | |
| 
 | |
|             wrapper = WrapperModule(interpreter)
 | |
| 
 | |
|             # Create a graph that: 1) Takes function arguments 2) Invokes the interpreter
 | |
|             # 3) Returns the speficied return value
 | |
| 
 | |
|             # FIXME: The following code could be greatly simplified by symbolic_trace'ing
 | |
|             # the wrapper with a Tracer that considers the Wrapper instance a root
 | |
|             # module, however, I can't get `__call__` exposed on TorchBind classes
 | |
|             # without it messing up Python `hasattr` for some reason. More digging
 | |
|             # into CPython's implementation of hasattr is probably in order...
 | |
| 
 | |
|             graph = torch.fx.Graph()
 | |
|             # Add placeholders for fn inputs
 | |
|             placeholder_nodes = []
 | |
|             for name in fn_input_names:
 | |
|                 placeholder_nodes.append(graph.create_node('placeholder', name))
 | |
| 
 | |
|             # Get the interpreter object
 | |
|             interpreter_node = graph.create_node('get_attr', 'interpreter')
 | |
| 
 | |
|             # Add a node to call the interpreter instance
 | |
|             output_node = graph.create_node(
 | |
|                 op='call_method', target='__call__', args=(interpreter_node, placeholder_nodes))
 | |
| 
 | |
|             # Register output
 | |
|             graph.output(output_node)
 | |
| 
 | |
|             graph.lint()
 | |
| 
 | |
|             # Return final GraphModule!!!
 | |
|             return GraphModule(wrapper, graph)
 | |
| 
 | |
| 
 | |
|         # Lower GraphModule to C++ interpreter
 | |
|         lowered = lower_to_elementwise_interpreter(msm)
 | |
| 
 | |
|         # Compare correctness with original module
 | |
|         x = torch.rand(3, 4)
 | |
|         ref_out = msm(x)
 | |
|         test_out = lowered(x)
 | |
|         torch.testing.assert_close(test_out, ref_out)
 | |
| 
 | |
|         # Test TorchScript compilation
 | |
|         scripted_lowered = torch.jit.script(lowered)
 | |
|         script_out = scripted_lowered(x)
 | |
|         torch.testing.assert_close(script_out, ref_out)
 | |
| 
 | |
|         # Test TorchScript ser/de
 | |
|         import_copy = self.getExportImportCopy(scripted_lowered)
 | |
|         imported_out = import_copy(x)
 | |
|         torch.testing.assert_close(imported_out, ref_out)
 | |
| 
 | |
|     def test_reserved_getattr(self):
 | |
|         """Ensure that we do not name any nodes with a reserved builtin like `getattr`"""
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a):
 | |
|                 return a.foo.bar.baz
 | |
| 
 | |
|         m = M()
 | |
|         m_g = symbolic_trace(m)
 | |
|         m_g.graph.lint()
 | |
|         for node in m_g.graph.nodes:
 | |
|             self.assertTrue(node.name != "getattr")
 | |
| 
 | |
|     def test_node_tagging(self):
 | |
|         class TaggingTracer(Tracer):
 | |
|             def create_node(self, kind : str, target : Union[str, Callable],
 | |
|                             args : Tuple[Argument, ...], kwargs : Dict[str, Any], name : Optional[str] = None,
 | |
|                             type_expr : Optional[Any] = None) -> Node:
 | |
|                 n = super().create_node(kind, target, args, kwargs, name)
 | |
|                 n.tag = 'foo'
 | |
|                 return n
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
| 
 | |
|         m = M()
 | |
|         g = TaggingTracer().trace(m)
 | |
|         g.lint()
 | |
|         for n in g.nodes:
 | |
|             self.assertTrue(hasattr(n, 'tag'))
 | |
|             self.assertEqual(n.tag, 'foo')
 | |
| 
 | |
|     def test_tensor_attribute(self):
 | |
|         class TensorAttribute(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.tensor = torch.rand(3, 4)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.nn.functional.linear(x, self.tensor)
 | |
| 
 | |
|         ta = TensorAttribute()
 | |
|         traced = symbolic_trace(ta)
 | |
|         traced(torch.rand(4, 4))
 | |
| 
 | |
|         class WrapperForQualname(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.ta = TensorAttribute()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.nn.functional.linear(x, self.ta.tensor)
 | |
| 
 | |
|         wfq = WrapperForQualname()
 | |
|         traced2 = symbolic_trace(wfq)
 | |
|         traced2.graph.lint()
 | |
|         traced2(torch.rand(4, 4))
 | |
| 
 | |
|     def test_tensor_attribute_coalseced(self):
 | |
| 
 | |
|         def count_attrs(fx_module):
 | |
|             targets = set()
 | |
|             for node in traced.graph.nodes:
 | |
|                 if node.op == 'get_attr':
 | |
|                     targets.add(node.target)
 | |
|             return len(targets)
 | |
| 
 | |
|         val = torch.tensor(5)
 | |
| 
 | |
|         def f(x):
 | |
|             return x + val + val
 | |
|         traced = symbolic_trace(f)
 | |
|         traced.graph.lint()
 | |
|         self.assertEqual(count_attrs(traced), 1)
 | |
| 
 | |
|         val2 = torch.tensor(5)
 | |
| 
 | |
|         def f(x):
 | |
|             val = torch.tensor(5)
 | |
|             return x + val + val2
 | |
| 
 | |
|         traced = symbolic_trace(f)
 | |
|         traced.graph.lint()
 | |
|         self.assertEqual(count_attrs(traced), 2)
 | |
| 
 | |
| 
 | |
|     def test_symbolic_trace_sequential(self):
 | |
|         class Simple(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.neg(x)
 | |
| 
 | |
|         seq = torch.nn.Sequential(
 | |
|             Simple(),
 | |
|             Simple(),
 | |
|             Simple()
 | |
|         )
 | |
|         traced = symbolic_trace(seq)
 | |
|         traced.graph.lint()
 | |
|         x = torch.rand(3, 4)
 | |
|         self.assertEqual(traced(x), seq(x))
 | |
| 
 | |
|     def test_tensor_constant(self):
 | |
|         class ConstTensor(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.nn.functional.linear(x, torch.zeros(3, 4))
 | |
| 
 | |
|         ct = ConstTensor()
 | |
|         traced = symbolic_trace(ct)
 | |
|         traced.graph.lint()
 | |
|         traced(torch.rand(4, 4))
 | |
| 
 | |
|     def test_pickle_graphmodule(self):
 | |
|         class Nested(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.st = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.st(x)
 | |
| 
 | |
|         n = Nested()
 | |
|         traced = symbolic_trace(n)
 | |
|         traced.graph.lint()
 | |
|         pickled = pickle.dumps(traced)
 | |
|         loaded = pickle.loads(pickled)
 | |
|         loaded.graph.lint()
 | |
|         x = torch.rand(3, 4)
 | |
|         self.assertEqual(loaded(x), traced(x))
 | |
| 
 | |
|     def test_pickle_custom_import(self):
 | |
|         graph = torch.fx.Graph()
 | |
|         a = graph.placeholder('x')
 | |
|         b = graph.placeholder('y')
 | |
|         c = graph.call_function(a_non_torch_leaf, (a, b))
 | |
|         d = graph.call_function(torch.sin, (c,))
 | |
|         graph.output(d)
 | |
|         gm = GraphModule(torch.nn.Module(), graph)
 | |
|         pickled = pickle.dumps(gm)
 | |
|         loaded = pickle.loads(pickled)
 | |
|         loaded.graph.lint()
 | |
|         x, y = torch.rand(1), torch.rand(1)
 | |
|         self.assertEqual(loaded(x, y), gm(x, y))
 | |
| 
 | |
|     def test_all_input_nodes(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         a : torch.fx.Node = graph.placeholder('x')
 | |
|         b : torch.fx.Node = graph.call_module('linear_mod', args=(a,))
 | |
|         c : torch.fx.Node = graph.get_attr('y_attr')
 | |
|         d : torch.fx.Node = graph.call_function(operator.add, args=(b, c))
 | |
|         e : torch.fx.Node = graph.call_function(torch.unsqueeze, args=(d, 0))
 | |
|         graph.output(e)
 | |
|         graph.lint()
 | |
| 
 | |
|         self.assertEqual(b.all_input_nodes, [a])
 | |
|         self.assertEqual(c.all_input_nodes, [])
 | |
|         self.assertEqual(d.all_input_nodes, [b, c])
 | |
|         self.assertEqual(e.all_input_nodes, [d])
 | |
| 
 | |
|     def test_deepcopy_graphmodule_with_transform(self):
 | |
|         st = SimpleTest()
 | |
|         traced = symbolic_trace(st)
 | |
|         traced.graph.lint()
 | |
| 
 | |
|         def transform(traced):
 | |
|             new_graph = torch.fx.Graph()
 | |
|             val_map : Dict[Node, Node] = {}
 | |
|             output_value = new_graph.graph_copy(traced.graph, val_map)
 | |
|             relu_out = new_graph.create_node(
 | |
|                 op='call_method', target='neg', args=(output_value,), kwargs={})
 | |
|             new_graph.output(relu_out)
 | |
|             return GraphModule(traced, new_graph)
 | |
|         transformed = transform(traced)
 | |
|         transformed.graph.lint()
 | |
|         copied = copy.deepcopy(transformed)
 | |
|         self.assertNotEqual(id(type(transformed)), id(type(copied)))
 | |
|         x = torch.randn(3, 4)
 | |
|         self.assertEqual(copied(x), transformed(x))
 | |
| 
 | |
|     def test_deepcopy_with_submods_params(self):
 | |
|         class Bar(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x) + self.param
 | |
| 
 | |
|         class Baz(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.bar = Bar()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.bar(x) - self.param
 | |
| 
 | |
|         baz = Baz()
 | |
|         traced = symbolic_trace(baz)
 | |
|         traced.graph.lint()
 | |
|         copied = copy.deepcopy(traced)
 | |
|         copied.graph.lint()
 | |
| 
 | |
|     def test_deepcopy_graph_with_tracer_cls(self):
 | |
|         class TestTracer(Tracer):
 | |
|             def is_leaf_module(self, module, name):
 | |
|                 return True
 | |
| 
 | |
|         g = Graph(tracer_cls=TestTracer)
 | |
|         x = g.placeholder("x")
 | |
|         g.output(x)
 | |
| 
 | |
|         h = copy.deepcopy(g)
 | |
|         self.assertIsNotNone(h._tracer_cls)
 | |
|         self.assertTrue(g._tracer_cls == h._tracer_cls)
 | |
| 
 | |
|     def test_unpack_list_better_error(self):
 | |
|         class SomeArgs(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return torch.rand(3, 4)
 | |
| 
 | |
|         class UnpacksList(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.sa = SomeArgs()
 | |
| 
 | |
|             def forward(self, x : list):
 | |
|                 return self.sa(*x)
 | |
| 
 | |
|         ul = UnpacksList()
 | |
|         with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
 | |
|             symbolic_trace(ul)
 | |
| 
 | |
|     def test_unpack_dict_better_error(self):
 | |
|         class SomeKwargs(torch.nn.Module):
 | |
|             def forward(self, x=3, y=4):
 | |
|                 return torch.rand(3, 4)
 | |
| 
 | |
|         class UnpacksDict(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.sk = SomeKwargs()
 | |
| 
 | |
|             def forward(self, x : dict):
 | |
|                 return self.sk(**x)
 | |
| 
 | |
|         ud = UnpacksDict()
 | |
|         with self.assertRaisesRegex(TraceError, 'Proxy object cannot be iterated.'):
 | |
|             symbolic_trace(ud)
 | |
| 
 | |
|     def test_pretty_print_targets(self):
 | |
|         # Test that Graph pretty-print prints friendly name for targets
 | |
|         # in `operator` and `builtins`
 | |
| 
 | |
|         class SomeMod(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.add(x.foo + x.bar, 3.0)
 | |
| 
 | |
|         traced = symbolic_trace(SomeMod())
 | |
|         graph_str = str(traced.graph)
 | |
|         self.assertIn('builtins.getattr', graph_str)
 | |
|         self.assertIn('operator.add', graph_str)
 | |
|         self.assertIn('torch.add', graph_str)
 | |
| 
 | |
|     def test_pretty_print_node(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param: torch.nn.Parameter = torch.nn.Parameter(
 | |
|                     torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x: torch.Tensor, y: int = 2):
 | |
|                 return self.linear(x[y] + self.param).clamp(min=0.0, max=1.0)
 | |
| 
 | |
|         traced = symbolic_trace(M())
 | |
| 
 | |
|         all_formatted = "\n".join([n.format_node() for n in traced.graph.nodes])
 | |
| 
 | |
|         FileCheck().check("x").check("placeholder") \
 | |
|             .check("y").check("placeholder") \
 | |
|             .check("getitem").check("call_function") \
 | |
|             .check("param").check("get_attr") \
 | |
|             .check("add").check("call_function") \
 | |
|             .check("linear").check("call_module") \
 | |
|             .check("clamp").check("call_method") \
 | |
|             .run(all_formatted)
 | |
| 
 | |
|     def test_script_tensor_constant(self):
 | |
|         # TorchScript seems to ignore attributes that start with `__`.
 | |
|         # We used to call anonymous Tensor values `__tensor_constant*`, but
 | |
|         # they were getting ignored by script. Now they're called
 | |
|         # `_tensor_constant*`
 | |
|         class IHaveATensorConstant(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return x + torch.rand(3, 4)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(IHaveATensorConstant())
 | |
|         torch.jit.script(traced)
 | |
| 
 | |
|     def test_autowrap_functions(self):
 | |
|         class AutowrapFnTest(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return fx_int(x.shape[0] / 2)
 | |
| 
 | |
|         class AutowrapFnTest2(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return fx_int(x.shape[0] / 2) + fx_int_x2(x.shape[0] / 2)
 | |
| 
 | |
|         # Check function(s) are wrapped
 | |
|         # `int` would normally throw a TypeError as argument can't be `Proxy`
 | |
|         tracer = Tracer(autowrap_functions=(fx_int,))
 | |
|         graph = tracer.trace(AutowrapFnTest())
 | |
|         traced = GraphModule(tracer.root, graph, 'test')
 | |
|         tracer_2 = Tracer(autowrap_functions=(fx_int, fx_int_x2))
 | |
|         tracer_2.trace(AutowrapFnTest2())
 | |
| 
 | |
|         # Test scriptability
 | |
|         traced_scripted = torch.jit.script(traced)
 | |
|         self.assertEqual(traced_scripted(torch.rand(4)), 2)
 | |
| 
 | |
|     def test_torch_fx_len(self):
 | |
|         class FXLenTest(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return len(x)
 | |
| 
 | |
|         traced = symbolic_trace(FXLenTest())
 | |
|         self.assertEqual(traced(torch.rand(3, 4)), 3)
 | |
| 
 | |
|         # Test scriptability
 | |
|         scripted = torch.jit.script(FXLenTest())
 | |
|         self.assertEqual(scripted(torch.rand(3)), 3)
 | |
| 
 | |
|         traced_scripted = torch.jit.script(traced)
 | |
|         self.assertEqual(traced_scripted(torch.rand(3)), 3)
 | |
| 
 | |
|         # Test non-proxy len
 | |
|         class FXLenTest2(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.l = [3, 4, 5]
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return x + len(self.l)
 | |
| 
 | |
|         traced2 = symbolic_trace(FXLenTest2())
 | |
|         inp = torch.rand(3, 4)
 | |
|         self.assertEqual(traced2(inp), inp + 3.0)
 | |
|         self.assertIs(len, builtins.len)
 | |
| 
 | |
|     def test_torch_fx_getattr(self):
 | |
|         class FXGetattrTest(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return getattr(x, 'nonexistent_attr', torch.Tensor([2, 3]))
 | |
| 
 | |
|         traced = symbolic_trace(FXGetattrTest())
 | |
|         self.assertEqual(traced(torch.rand(3, 4)), torch.Tensor([2, 3]))
 | |
| 
 | |
|     def test_sqrt(self):
 | |
|         class Sqrt1(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return sqrt(x.size(0))
 | |
| 
 | |
|         class Sqrt2(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return math.sqrt(x.size(0))
 | |
| 
 | |
|         class Sqrt3(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return x + math.sqrt(2) + sqrt(2)
 | |
| 
 | |
|         self.checkGraphModule(Sqrt1(), [torch.zeros(8)])
 | |
|         self.checkGraphModule(Sqrt2(), [torch.zeros(8)])
 | |
|         self.checkGraphModule(Sqrt3(), [torch.zeros(8)])
 | |
|         self.assertIs(sqrt, _sqrt)
 | |
|         self.assertIs(math.sqrt, _sqrt)
 | |
| 
 | |
|     def test_torch_custom_ops(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a):
 | |
|                 b = torch.ops.aten.sigmoid(a)
 | |
|                 c = torch.ops.aten.cat([a, b])
 | |
|                 return torch.ops.aten.cat((c, c))
 | |
|         m = M()
 | |
|         input = torch.randn(3)
 | |
|         ref_out = m(input)
 | |
|         gm = symbolic_trace(m)
 | |
|         gm.graph.lint()
 | |
|         out = gm(input)
 | |
|         self.assertEqual(out, ref_out)
 | |
| 
 | |
|     def test_pickle_torch_custom_ops(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a):
 | |
|                 b = torch.ops.aten.sigmoid(a)
 | |
|                 c = torch.ops.aten.cat([a, b])
 | |
|                 return torch.ops.aten.cat((c, c))
 | |
|         m = M()
 | |
|         input = torch.randn(3)
 | |
|         ref_out = m(input)
 | |
|         gm = symbolic_trace(m)
 | |
|         gm.graph.lint()
 | |
|         pickled = pickle.dumps(gm)
 | |
|         loaded = pickle.loads(pickled)
 | |
|         self.assertEqual(loaded(input), gm(input))
 | |
| 
 | |
|     def test_pretty_print(self):
 | |
|         st = SimpleTest()
 | |
|         traced = symbolic_trace(st)
 | |
|         traced.graph.lint()
 | |
|         printed = str(traced)
 | |
|         assert 'SimpleTest()' in printed
 | |
|         assert 'torch.relu' in printed
 | |
| 
 | |
|     def test_pretty_print_graph(self):
 | |
|         class KwargPrintTest(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.squeeze(x + 3.0, dim=2)
 | |
|         st = KwargPrintTest()
 | |
|         traced = symbolic_trace(st)
 | |
|         traced.graph.lint()
 | |
|         stringed = str(traced.graph)
 | |
|         for s in ['args', 'kwargs', '#users']:
 | |
|             assert s in stringed
 | |
| 
 | |
|     def test_custom_proxy_type(self):
 | |
|         class TensorPair:
 | |
|             def __init__(self, left, right):
 | |
|                 self.left, self.right = left, right
 | |
| 
 | |
|             def add(self, other):
 | |
|                 l = self.left + other.left
 | |
|                 r = self.right + other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|             def mul(self, other):
 | |
|                 l = self.left * other.left
 | |
|                 r = self.right * other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|         def use_tensor_pair(x : TensorPair, y : TensorPair):
 | |
|             s = x.add(y)
 | |
|             return s.mul(x)
 | |
| 
 | |
|         x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
 | |
|         y = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
 | |
| 
 | |
|         ref_out = use_tensor_pair(x, y)
 | |
| 
 | |
|         traced = symbolic_trace(use_tensor_pair)
 | |
| 
 | |
|         traced_out = traced(x, y)
 | |
|         self.assertEqual(traced_out.left, ref_out.left)
 | |
|         self.assertEqual(traced_out.right, ref_out.right)
 | |
| 
 | |
|     def test_custom_proxy_type_literal(self):
 | |
|         class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
 | |
|             def __init__(self, left, right):
 | |
|                 self.left, self.right = left, right
 | |
| 
 | |
|             def add(self, other):
 | |
|                 l = self.left + other.left
 | |
|                 r = self.right + other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|             def mul(self, other):
 | |
|                 l = self.left * other.left
 | |
|                 r = self.right * other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|         def use_tensor_pair_literal(x : TensorPair):
 | |
|             s = x.add(TensorPair(torch.zeros(5, 3), torch.zeros(5, 3)))
 | |
|             return s.mul(x)
 | |
| 
 | |
|         x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
 | |
| 
 | |
|         ref_out = use_tensor_pair_literal(x)
 | |
| 
 | |
|         traced = symbolic_trace(use_tensor_pair_literal)
 | |
| 
 | |
|         traced_out = traced(x)
 | |
|         self.assertEqual(traced_out.left, ref_out.left)
 | |
|         self.assertEqual(traced_out.right, ref_out.right)
 | |
| 
 | |
|     def test_custom_proxy_dynamic_value(self):
 | |
|         class TensorPair(metaclass=torch.fx.ProxyableClassMeta):
 | |
|             def __init__(self, left, right):
 | |
|                 self.left, self.right = left, right
 | |
| 
 | |
|             def add(self, other):
 | |
|                 l = self.left + other.left
 | |
|                 r = self.right + other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|             def mul(self, other):
 | |
|                 l = self.left * other.left
 | |
|                 r = self.right * other.right
 | |
|                 return TensorPair(l, r)
 | |
| 
 | |
|         def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor):
 | |
|             s = x.add(TensorPair(y, y))
 | |
|             return s.mul(x)
 | |
| 
 | |
|         x = TensorPair(torch.randn(5, 3), torch.randn(5, 3))
 | |
|         y = torch.randn(5, 3)
 | |
|         ref_out = use_tensor_pair_ctor(x, y)
 | |
| 
 | |
|         traced = symbolic_trace(use_tensor_pair_ctor)
 | |
| 
 | |
|         traced_out = traced(x, y)
 | |
|         self.assertEqual(traced_out.left, ref_out.left)
 | |
|         self.assertEqual(traced_out.right, ref_out.right)
 | |
| 
 | |
|     def test_custom_proxy_input_dependent_control_flow(self):
 | |
|         class ZeroTensor(metaclass=torch.fx.ProxyableClassMeta):
 | |
|             def __init__(self, inp):
 | |
|                 if inp.sum() == 0:
 | |
|                     self.is_zero = True
 | |
|                     self.tensor = torch.tensor([])
 | |
|                 else:
 | |
|                     self.is_zero = False
 | |
|                     self.tensor = inp
 | |
| 
 | |
|             def add(self, other):
 | |
|                 if self.is_zero:
 | |
|                     return ZeroTensor(other.tensor)
 | |
|                 elif other.is_zero:
 | |
|                     return self
 | |
| 
 | |
|         def use_zero_tensor(x : torch.Tensor, y : torch.Tensor):
 | |
|             return ZeroTensor(x + y)
 | |
| 
 | |
|         x, y = torch.randn(5, 3), torch.randn(5, 3)
 | |
| 
 | |
|         ref_out = use_zero_tensor(x, y)
 | |
| 
 | |
|         traced = symbolic_trace(use_zero_tensor)
 | |
| 
 | |
|         traced_out = traced(x, y)
 | |
| 
 | |
|         self.assertEqual(traced_out.is_zero, ref_out.is_zero)
 | |
|         self.assertEqual(traced_out.tensor, ref_out.tensor)
 | |
| 
 | |
|     def test_graph_fns(self):
 | |
|         g = Graph()
 | |
|         a = g.placeholder('a')
 | |
|         b = g.call_module('linear', (a,))
 | |
|         c = g.get_attr('bias')
 | |
|         d = g.call_method('add', (b, c))
 | |
|         e = g.call_function(torch.sin, (d,))
 | |
|         g.output(e)
 | |
|         mod = torch.nn.Module()
 | |
|         mod.linear = torch.nn.Linear(3, 4)
 | |
|         mod.bias = torch.rand(4)
 | |
|         gm = GraphModule(mod, g)
 | |
|         gm.graph.lint()
 | |
|         input = torch.rand(3)
 | |
|         r = gm(input)
 | |
|         ref = torch.sin(mod.linear(input) + mod.bias)
 | |
|         self.assertEqual(r, ref)
 | |
| 
 | |
|     def test_remove_uses(self):
 | |
|         g : torch.fx.Graph = Graph()
 | |
|         x : torch.fx.Node = g.placeholder('x')
 | |
|         relu : torch.fx.Node = g.call_function(torch.relu, (x,))
 | |
|         neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
 | |
|         g.output(neg)
 | |
| 
 | |
|         neg.replace_all_uses_with(relu)
 | |
|         g.erase_node(neg)
 | |
| 
 | |
|         self.assertTrue(neg not in relu.users)
 | |
| 
 | |
|     def test_nonetype_annotation(self):
 | |
|         eb = torch.nn.EmbeddingBag(3, 4)
 | |
|         symbolic_trace(eb)
 | |
| 
 | |
|     def test_pickle_nonetype_annotation(self):
 | |
|         eb = torch.nn.EmbeddingBag(10, 3, mode='sum')
 | |
|         traced = symbolic_trace(eb)
 | |
|         pickled = pickle.dumps(traced)
 | |
|         loaded = pickle.loads(pickled)
 | |
|         loaded.graph.lint()
 | |
|         input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
 | |
|         offsets = torch.LongTensor([0, 4])
 | |
|         self.assertEqual(loaded(input, offsets), traced(input, offsets))
 | |
| 
 | |
|     def test_return_tuple(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
 | |
|                 return (x, x + x)
 | |
| 
 | |
| 
 | |
|         original = M()
 | |
|         traced = symbolic_trace(original)
 | |
|         self.assertEqual(traced(torch.ones(1)), original.forward(torch.ones(1)))
 | |
| 
 | |
|     def test_construct_root_dict(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         a : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
 | |
|         c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
 | |
|         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
 | |
|         graph.output(d)
 | |
| 
 | |
|         linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
 | |
|         add_param : torch.Tensor = torch.rand(3, 4)
 | |
|         gm : torch.fx.GraphModule = torch.fx.GraphModule(
 | |
|             {'foo.bar.baz': linear_mod, 'zip.zap.zam' : add_param}, graph)
 | |
|         gm.graph.lint()
 | |
| 
 | |
|         assert 'self.foo.bar.baz' in gm.code
 | |
| 
 | |
|         x : torch.Tensor = torch.rand(3, 3)
 | |
|         out : torch.Tensor = gm(x)
 | |
|         ref_out : torch.Tensor = linear_mod(x) + add_param
 | |
|         self.assertEqual(out, ref_out)
 | |
| 
 | |
|     def test_symbolic_trace_assert(self):
 | |
| 
 | |
|         class AssertsTensorShape(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 torch._assert(x.shape[1] > 4, "assert_foobar")
 | |
|                 return x
 | |
| 
 | |
|         m = AssertsTensorShape()
 | |
|         # verify traceability
 | |
|         traced = symbolic_trace(m)
 | |
|         # verify assertion on traced model works correctly at runtime
 | |
|         traced(torch.rand(4, 5))
 | |
|         with self.assertRaisesRegex(AssertionError, "assert_foobar"):
 | |
|             traced(torch.rand(4, 3))
 | |
|         # verify the symbolically traced module is scriptable
 | |
|         ms = torch.jit.script(m)
 | |
|         with self.assertRaisesRegex(torch.jit.Error, "assert_foobar"):
 | |
|             ms(torch.rand(4, 3))
 | |
| 
 | |
|     def test_fx_create_arg(self):
 | |
|         class CustomArgObject:
 | |
|             def __init__(self, x, y):
 | |
|                 self.x = x
 | |
|                 self.y = y
 | |
| 
 | |
|             def __fx_create_arg__(self, tracer: torch.fx.Tracer):
 | |
|                 return tracer.create_node(
 | |
|                     "call_function",
 | |
|                     CustomArgObject,
 | |
|                     args=(
 | |
|                         tracer.create_arg(self.x),
 | |
|                         tracer.create_arg(self.y),
 | |
|                     ),
 | |
|                     kwargs={},
 | |
|                 )
 | |
| 
 | |
|         class HasCustomArgObjectWhenLeaf(torch.nn.Module):
 | |
|             def forward(self, o: CustomArgObject):
 | |
|                 # Not normally traceable; good reason to make
 | |
|                 # this module a leaf.
 | |
|                 for x in o.x:
 | |
|                     o.y += x
 | |
|                 return o.y
 | |
| 
 | |
|         class Root(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.inner = HasCustomArgObjectWhenLeaf()
 | |
| 
 | |
|             def forward(self, x, y):
 | |
|                 o = CustomArgObject(x, y)
 | |
|                 return self.inner(o)
 | |
| 
 | |
|         class CreateArgTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m, module_qualified_name):
 | |
|                 return type(m) is HasCustomArgObjectWhenLeaf
 | |
| 
 | |
|         m = Root()
 | |
|         graph = CreateArgTracer().trace(m)
 | |
|         gm = torch.fx.GraphModule(m, graph)
 | |
|         assert "CustomArgObject(" in gm.code
 | |
| 
 | |
|     def test_trace_fn_constant(self):
 | |
|         some_constant = torch.rand(3, 4)
 | |
| 
 | |
|         def add_const(x):
 | |
|             return some_constant + x
 | |
| 
 | |
|         traced = symbolic_trace(add_const)
 | |
| 
 | |
|         input = torch.rand(3, 4)
 | |
|         self.assertEqual(traced(input), add_const(input))
 | |
| 
 | |
|     def test_copy_no_remap(self):
 | |
|         traced = symbolic_trace(SimpleTest())
 | |
|         g = traced.graph
 | |
|         copied = torch.fx.Graph()
 | |
|         for node in g.nodes:
 | |
|             copied.node_copy(node)
 | |
|         with self.assertRaisesRegex(RuntimeError, 'does not belong to this Graph'):
 | |
|             copied.lint()
 | |
| 
 | |
|     def test_wrong_topo(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         a : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
 | |
|         c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
 | |
|         d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
 | |
|         graph.output(d)
 | |
|         nodes = list(graph.nodes)
 | |
|         nodes[3].append(nodes[2])
 | |
|         with self.assertRaisesRegex(RuntimeError, 'was used before it has been defined'):
 | |
|             graph.lint()
 | |
| 
 | |
|     def test_wrong_target_type(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         with self.assertRaises(ValueError):
 | |
|             n = torch.fx.Node(graph=graph, name='foo', op='call_function', target='foo',
 | |
|                               args=(), kwargs={})
 | |
| 
 | |
|     def test_example_shape_prop(self):
 | |
|         class TestCase(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.attr = torch.randn(3, 4)
 | |
|                 self.submod = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.neg(self.submod(x.relu() + self.attr))
 | |
|         tc = TestCase()
 | |
|         tc_traced = symbolic_trace(tc)
 | |
|         ref_out = tc_traced(torch.rand(3, 4))
 | |
|         shape_prop.ShapeProp(tc_traced).propagate(torch.rand(3, 4))
 | |
| 
 | |
|         # Make sure we're testing all opcodes
 | |
|         opcodes = set()
 | |
|         output_shape : Optional[torch.Shape] = None
 | |
|         output_stride : Optional[Tuple[int]] = None
 | |
|         for node in tc_traced.graph.nodes:
 | |
|             opcodes.add(node.op)
 | |
|             if node.op == 'output':
 | |
|                 output_shape = node.args[0].meta['tensor_meta'].shape
 | |
|                 output_stride = node.args[0].meta['tensor_meta'].stride
 | |
|         self.assertEqual(opcodes, set(['placeholder', 'get_attr', 'call_function', 'call_method',
 | |
|                                        'call_module', 'output']))
 | |
| 
 | |
|         # Test shape propogation and make sure results match actual
 | |
|         self.assertEqual(output_shape, ref_out.shape)
 | |
|         self.assertEqual(output_stride, ref_out.stride())
 | |
| 
 | |
|     def test_shape_prop_layout(self):
 | |
|         class ConvTest(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.conv_mod = torch.nn.Conv2d(5, 5, 3)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.conv_mod(x)
 | |
| 
 | |
|         # contiguous layout
 | |
|         test_mod = ConvTest()
 | |
|         traced = symbolic_trace(test_mod)
 | |
|         x = torch.randn(5, 5, 224, 224)
 | |
|         shape_prop.ShapeProp(traced).propagate(x)
 | |
| 
 | |
|         assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
 | |
|                    for node in traced.graph.nodes))
 | |
| 
 | |
|         x_channels_last = x.contiguous(memory_format=torch.channels_last)
 | |
|         traced.to(memory_format=torch.channels_last)
 | |
|         shape_prop.ShapeProp(traced).propagate(x_channels_last)
 | |
|         for node in traced.graph.nodes:
 | |
|             # NB: the implementation of conv may not preserve the memory format,
 | |
|             # unfortunately. The best we can do is just check that the placeholder
 | |
|             # node is channels-last
 | |
|             if node.op in {'placeholder'}:
 | |
|                 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last)
 | |
| 
 | |
|     def test_shape_prop_aggregate(self):
 | |
|         class ReturnTwo(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return (3, torch.sum(x))
 | |
| 
 | |
|         class UnderTest(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.rt = ReturnTwo()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.rt(x)
 | |
| 
 | |
|         ut = UnderTest()
 | |
| 
 | |
|         class RTTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m, module_qualified_name):
 | |
|                 return type(m) is ReturnTwo
 | |
| 
 | |
|         graph = RTTracer().trace(ut)
 | |
|         mod = torch.fx.GraphModule(ut, graph)
 | |
| 
 | |
|         shape_prop.ShapeProp(mod).propagate(torch.rand(3, 4))
 | |
| 
 | |
|         for node in mod.graph.nodes:
 | |
|             if node.op == 'call_module':
 | |
|                 assert 'tensor_meta' in node.meta
 | |
|                 tensor_meta = node.meta['tensor_meta']
 | |
|                 assert tensor_meta[0] == 3
 | |
|                 assert tensor_meta[1].shape == torch.Size([])
 | |
| 
 | |
|     def test_shape_prop_layout_3d(self):
 | |
|         class ConvTest3d(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.conv_mod = torch.nn.Conv3d(5, 5, 3)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.conv_mod(x)
 | |
| 
 | |
|         test_mod_3d = ConvTest3d()
 | |
|         traced_3d = symbolic_trace(test_mod_3d)
 | |
|         x_3d = torch.randn(5, 5, 224, 224, 15)
 | |
|         shape_prop.ShapeProp(traced_3d).propagate(x_3d)
 | |
|         assert(all(node.meta['tensor_meta'].memory_format is torch.contiguous_format
 | |
|                    for node in traced_3d.graph.nodes))
 | |
| 
 | |
|         x_channels_last_3d = x_3d.contiguous(memory_format=torch.channels_last_3d)
 | |
|         traced_3d.to(memory_format=torch.channels_last_3d)
 | |
|         shape_prop.ShapeProp(traced_3d).propagate(x_channels_last_3d)
 | |
|         for node in traced_3d.graph.nodes:
 | |
|             # NB: the implementation of conv may not preserve the memory format,
 | |
|             # unfortunately. The best we can do is just check that the placeholder
 | |
|             # node is channels-last
 | |
|             if node.op in {'placeholder'}:
 | |
|                 self.assertEqual(node.meta['tensor_meta'].memory_format, torch.channels_last_3d)
 | |
| 
 | |
|     def test_interpreter(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
 | |
| 
 | |
|         m = MyModule()
 | |
|         gm = torch.fx.symbolic_trace(m)
 | |
| 
 | |
|         interpreter = Interpreter(gm)
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(interpreter.run(input), gm(input))
 | |
|         self.assertEqual(interpreter.run(input), m(input))
 | |
| 
 | |
|     def test_interpreter_run_node_override(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
 | |
| 
 | |
|         m = MyModule()
 | |
|         gm = torch.fx.symbolic_trace(m)
 | |
| 
 | |
|         class RunNodeInterpreter(Interpreter):
 | |
|             def __init__(self, module):
 | |
|                 super().__init__(module)
 | |
| 
 | |
|             def run_node(self, n : Node) -> Any:
 | |
|                 result = super().run_node(n)
 | |
|                 n.cached_value = result
 | |
|                 return result
 | |
| 
 | |
|         input = torch.randn(3, 4)
 | |
|         RunNodeInterpreter(gm).run(input)
 | |
|         for node in gm.graph.nodes:
 | |
|             assert hasattr(node, 'cached_value')
 | |
| 
 | |
|     def test_interpreter_onthefly_swap(self):
 | |
| 
 | |
|         def fn(x):
 | |
|             return torch.sigmoid(x).neg()
 | |
| 
 | |
|         gm = torch.fx.symbolic_trace(fn)
 | |
| 
 | |
|         class NegSigmSwapInterpreter(Interpreter):
 | |
|             def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
 | |
|                 if target == torch.sigmoid:
 | |
|                     return torch.neg(*args, **kwargs)
 | |
|                 return super().call_function(n)
 | |
| 
 | |
|             def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
 | |
|                 if target == 'neg':
 | |
|                     call_self, *args_tail = args
 | |
|                     return call_self.sigmoid(*args_tail, **kwargs)
 | |
|                 return super().call_method(n)
 | |
| 
 | |
|         input = torch.randn(3, 4)
 | |
|         result = NegSigmSwapInterpreter(gm).run(input)
 | |
|         self.assertEqual(result, torch.neg(input).sigmoid())
 | |
| 
 | |
|     def test_interpreter_partial_eval(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
 | |
| 
 | |
|         gm = torch.fx.symbolic_trace(MyModule())
 | |
|         interp = Interpreter(gm)
 | |
|         env = {}
 | |
|         for node in gm.graph.nodes:
 | |
|             if node.op == 'call_module' and node.target == 'linear':
 | |
|                 env[node] = torch.arange(0, 12, 1).reshape(3, 4) - 6.0
 | |
|                 break
 | |
|         assert len(env) == 1
 | |
|         x = torch.randn(3, 4)
 | |
|         result = interp.run(x, initial_env=env)
 | |
|         self.assertEqual(result, (torch.arange(0, 12, 1).reshape(3, 4) - 6.0).clamp(0.0, 1.0))
 | |
| 
 | |
|     def test_interpreter_star_args(self):
 | |
|         def with_star_args(x, *args):
 | |
|             return x + args[0]
 | |
| 
 | |
|         gm = torch.fx.symbolic_trace(with_star_args)
 | |
|         interp = Interpreter(gm)
 | |
|         result = interp.run(torch.ones(3, 4), torch.ones(3, 4), torch.rand(3, 4))
 | |
|         self.assertEqual(result, torch.ones(3, 4) * 2.0)
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_interpreter_noop_resnet18(self):
 | |
|         rn18 = torchvision_models.resnet18()
 | |
|         transformed = torch.fx.Transformer(symbolic_trace(rn18)).transform()
 | |
|         inp = torch.randn(5, 3, 224, 224)
 | |
|         self.assertEqual(transformed(inp), rn18(inp))
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_interpreter_gc_values(self):
 | |
|         rn18 = torchvision_models.resnet18()
 | |
|         interp = Interpreter(symbolic_trace(rn18))
 | |
|         inp = torch.rand(5, 3, 224, 224)
 | |
|         out = interp.run(inp)
 | |
|         env_key_names = set(n.name for n in interp.env.keys())
 | |
|         self.assertEqual(env_key_names, set(['output']))
 | |
| 
 | |
|     def test_transformer_noop(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(x + self.param).clamp(min=0.0, max=1.0)
 | |
| 
 | |
|         m = MyModule()
 | |
|         gm = torch.fx.symbolic_trace(m)
 | |
| 
 | |
|         new_gm = Transformer(gm).transform()
 | |
| 
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(new_gm(input), gm(input))
 | |
| 
 | |
|     def test_transformer_op_swap(self):
 | |
| 
 | |
|         def fn(x):
 | |
|             return torch.sigmoid(x).neg()
 | |
| 
 | |
|         gm = torch.fx.symbolic_trace(fn)
 | |
| 
 | |
|         class NegSigmSwapXformer(Transformer):
 | |
|             def call_function(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
 | |
|                 if target == torch.sigmoid:
 | |
|                     return torch.neg(*args, **kwargs)
 | |
|                 return super().call_function(n)
 | |
| 
 | |
|             def call_method(self, target : Target, args : Tuple, kwargs : Dict) -> Any:
 | |
|                 if target == 'neg':
 | |
|                     call_self, *args_tail = args
 | |
|                     return call_self.sigmoid(*args_tail, **kwargs)
 | |
|                 return super().call_method(n)
 | |
| 
 | |
|         transformed = NegSigmSwapXformer(gm).transform()
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(transformed(input), torch.neg(input).sigmoid())
 | |
| 
 | |
|     def test_transformer_multi_outputs(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = x + self.param
 | |
|                 out = self.linear(x)
 | |
|                 return x, out
 | |
| 
 | |
|         m = MyModule()
 | |
|         gm = torch.fx.symbolic_trace(m)
 | |
| 
 | |
|         new_gm = Transformer(gm).transform()
 | |
| 
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(new_gm(input), gm(input))
 | |
| 
 | |
|     def test_fn_type_annotations(self):
 | |
|         class Foo(torch.nn.Module):
 | |
|             def forward(self, p : Pair, z : torch.Tensor, i : int) -> Dict[str, torch.Tensor]:
 | |
|                 return {'a': p.x + p.y + z + i}
 | |
| 
 | |
|         foo_scripted = torch.jit.script(Foo())
 | |
|         foo_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
 | |
| 
 | |
|         fxed = symbolic_trace(Foo())
 | |
|         fxed_scripted = torch.jit.script(fxed)
 | |
|         fxed_scripted(Pair(torch.rand(5), torch.rand(5)), torch.rand(5), 3)
 | |
| 
 | |
|     def test_fn_type_annotation_empty(self):
 | |
|         def forward(a : List[torch.Tensor]):
 | |
|             return a[0]
 | |
|         torch.jit.script(symbolic_trace(forward))
 | |
| 
 | |
|     def test_wrapped_method(self):
 | |
|         def wrap_with_relu(fn):
 | |
|             @functools.wraps(fn)
 | |
|             def wrapper(*args, **kwargs):
 | |
|                 return torch.relu(fn(*args, **kwargs))
 | |
|             return wrapper
 | |
| 
 | |
|         class Foo(torch.nn.Module):
 | |
|             @wrap_with_relu
 | |
|             def forward(self, x, w):
 | |
|                 return torch.matmul(x, w)
 | |
| 
 | |
|         f = Foo()
 | |
|         traced = symbolic_trace(f)
 | |
|         x, w = torch.rand(3, 4), torch.rand(4, 4)
 | |
|         self.assertTrue(any(n.target == torch.relu for n in traced.graph.nodes))
 | |
| 
 | |
|     def test_empty_graph_codegen(self):
 | |
|         graph = torch.fx.Graph()
 | |
|         gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         self.assertEqual(gm(), None)
 | |
| 
 | |
|     def test_sequential(self):
 | |
|         m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
 | |
|         gm = torch.fx.symbolic_trace(m)
 | |
|         gm_copy = copy.deepcopy(gm)
 | |
| 
 | |
|     def test_ctx_mgr(self):
 | |
|         @contextlib.contextmanager
 | |
|         def do_nothing():
 | |
|             yield
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             @do_nothing()
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, (torch.rand(3, 4),))
 | |
| 
 | |
|     def test_typename_print(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,),
 | |
|                                               type_expr=List[float])
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         self.assertTrue('typing.List[float]' in str(graph))
 | |
| 
 | |
|     def test_layout(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.empty_like(x, layout=torch.strided, pin_memory=False).fill_(0)
 | |
| 
 | |
|         traced = symbolic_trace(M())
 | |
|         x = torch.rand(5, 9, 3, 4)
 | |
|         self.assertEqual(traced(x), torch.zeros_like(x))
 | |
| 
 | |
|     def test_ellipsis(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self, x, y):
 | |
|                 return x + y[:, 1:10, ...]
 | |
| 
 | |
|         traced = symbolic_trace(M())
 | |
|         x, y = torch.rand(5, 9, 3, 4), torch.rand(5, 15, 3, 4)
 | |
|         self.assertEqual(traced(x, y), x + y[:, 1:10, ...])
 | |
| 
 | |
|     def test_inf_nan(self):
 | |
|         class FooMod(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return x + float('inf'), x + float('-inf'), x + float('nan')
 | |
| 
 | |
|         fm = FooMod()
 | |
|         self.checkGraphModule(fm, (torch.rand(3, 4),))
 | |
| 
 | |
|     def test_inf_nan_kwds(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('inf')), {}, name='inf')
 | |
|         c : torch.fx.Node = graph.create_node('call_function', operator.add, (x, float('nan')), {}, name='nan')
 | |
|         graph.output((b, c))
 | |
| 
 | |
|         gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         x = torch.rand(3, 4)
 | |
|         self.assertEqual(gm(x), (x + float('inf'), x + float('nan')))
 | |
| 
 | |
|     def test_deepcopy_recursion_depth(self):
 | |
|         depth = sys.getrecursionlimit() + 20
 | |
| 
 | |
|         g = torch.fx.Graph()
 | |
|         x = g.placeholder('x')
 | |
|         for i in range(depth):
 | |
|             x = g.call_function(torch.relu, (x,))
 | |
|         g.output(x)
 | |
| 
 | |
|         copied_graph = copy.deepcopy(g)
 | |
| 
 | |
|         val_map = {}
 | |
|         for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
 | |
|             val_map[orig_node] = new_node
 | |
| 
 | |
|         for orig_node, new_node in zip(g.nodes, copied_graph.nodes):
 | |
|             orig_users = set(orig_node.users.keys())
 | |
|             orig_users_equiv = set(val_map[u] for u in orig_users)
 | |
|             new_users = set(new_node.users.keys())
 | |
|             self.assertEqual(orig_users_equiv, new_users)
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_replace_uses(self):
 | |
|         rn18 = torchvision_models.resnet18()
 | |
| 
 | |
|         class LowerReluTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m : torch.nn.Module, qualname : str):
 | |
|                 if isinstance(m, torch.nn.ReLU):
 | |
|                     return False
 | |
|                 return super().is_leaf_module(m, qualname)
 | |
| 
 | |
|         rn18_traced = GraphModule(rn18, LowerReluTracer().trace(rn18))
 | |
| 
 | |
|         to_erase = []
 | |
|         for node in rn18_traced.graph.nodes:
 | |
|             if node.op == 'call_function' and node.target in [torch.relu, torch.nn.functional.relu]:
 | |
|                 kwargs = node.kwargs.copy()
 | |
|                 # Neg doesn't have in-place
 | |
|                 kwargs.pop('inplace')
 | |
|                 with rn18_traced.graph.inserting_before(node):
 | |
|                     new_node = rn18_traced.graph.call_function(
 | |
|                         the_function=torch.neg, args=node.args, kwargs=node.kwargs)
 | |
|                 node.replace_all_uses_with(replace_with=new_node)
 | |
|                 to_erase.append(node)
 | |
| 
 | |
|         for node in to_erase:
 | |
|             rn18_traced.graph.erase_node(node)
 | |
| 
 | |
| 
 | |
|     def test_replace_input(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         y : torch.fx.Node = graph.create_node('placeholder', 'y')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         b.replace_input_with(x, y)
 | |
| 
 | |
|         gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
| 
 | |
|         input_x = torch.randn(33, 44)
 | |
|         input_y = torch.randn(11, 22)
 | |
|         self.assertEqual(gm(input_x, input_y), torch.relu(input_y))
 | |
| 
 | |
|     def test_insertion_point(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         with graph.inserting_before(b):
 | |
|             neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
 | |
|             _, *relu_args = b.args
 | |
|             b.args = (neg, *relu_args)
 | |
| 
 | |
|         gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
| 
 | |
|         input = torch.randn(33, 44)
 | |
|         self.assertEqual(gm(input), torch.relu(torch.neg(input)))
 | |
| 
 | |
|     def test_update_args_api(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         y : torch.fx.Node = graph.create_node('placeholder', 'y')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
 | |
|         self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
 | |
| 
 | |
| 
 | |
|         b.update_arg(0, y)
 | |
|         new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
 | |
| 
 | |
|     def test_update_kwargs_api(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         y : torch.fx.Node = graph.create_node('placeholder', 'y')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, kwargs={'input': x})
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         orig_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         inp_x, inp_y = torch.randn(5, 3), torch.randn(3, 5)
 | |
|         self.assertEqual(orig_gm(inp_x, inp_y), torch.relu(inp_x))
 | |
| 
 | |
| 
 | |
|         b.update_kwarg('input', y)
 | |
|         new_gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
|         self.assertEqual(new_gm(inp_x, inp_y), torch.relu(inp_y))
 | |
| 
 | |
|     def test_move_before(self):
 | |
|         graph : torch.fx.Graph = torch.fx.Graph()
 | |
|         x : torch.fx.Node = graph.create_node('placeholder', 'x')
 | |
|         b : torch.fx.Node = graph.create_node('call_function', target=torch.relu, args=(x,))
 | |
|         output : torch.fx.Node = graph.output(b)
 | |
| 
 | |
|         neg : torch.fx.Node = graph.call_function(the_function=torch.neg, args=(x,))
 | |
|         _, *relu_args = b.args
 | |
|         b.args = (neg, *relu_args)
 | |
|         b.prepend(neg)
 | |
| 
 | |
|         gm = torch.fx.GraphModule(torch.nn.Module(), graph)
 | |
| 
 | |
|         input = torch.randn(33, 44)
 | |
|         self.assertEqual(gm(input), torch.relu(torch.neg(input)))
 | |
| 
 | |
|     def test_erase_node_error(self):
 | |
|         st = SimpleTest()
 | |
|         traced = symbolic_trace(st)
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             # Test deleting with uses both in another Node and at the output
 | |
|             if node.target in [operator.add, torch.relu]:
 | |
|                 with self.assertRaisesRegex(RuntimeError, 'but it still had .* users in the graph'):
 | |
|                     traced.graph.erase_node(node)
 | |
| 
 | |
|     def test_copy_it(self):
 | |
|         d = immutable_dict([(3, 4), (5, 6)])
 | |
|         l = immutable_list([(3, 4), (5, 6)])
 | |
| 
 | |
|         self.assertEqual(d, deepcopy(d))
 | |
|         self.assertEqual(l, deepcopy(l))
 | |
| 
 | |
|     def test_get_torch_func_signature(self):
 | |
|         for key in dir(torch):
 | |
|             obj = getattr(torch, key)
 | |
|             if callable(obj):
 | |
|                 schemas = get_signature_for_torch_op(obj)
 | |
| 
 | |
|     def test_find_uses(self):
 | |
|         graph = torch.fx.Graph()
 | |
|         x = torch.fx.Proxy(graph.placeholder('x'))
 | |
| 
 | |
|         y = torch.relu(x)
 | |
|         z = x + x
 | |
|         u = torch.neg(x)
 | |
|         graph.output((y + z + u).node)
 | |
|         graph.lint()
 | |
| 
 | |
|         users_of_x = x.node.users
 | |
|         self.assertEqual(len(users_of_x), 3)
 | |
|         expected_ops = set(['relu', 'add', 'neg'])
 | |
|         for use in users_of_x:
 | |
|             assert any(use.name.startswith(prefix) for prefix in expected_ops)
 | |
| 
 | |
|     def test_inline_graph(self):
 | |
|         class InlineInto(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         class ToInline(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.neg(x)
 | |
| 
 | |
|         inline_into = symbolic_trace(InlineInto())
 | |
|         to_inline = symbolic_trace(ToInline())
 | |
| 
 | |
|         combined_graph = torch.fx.Graph()
 | |
|         output_node = combined_graph.graph_copy(inline_into.graph, {})
 | |
| 
 | |
|         input_node = list(to_inline.graph.nodes)[0]
 | |
|         assert input_node and input_node.op == 'placeholder'
 | |
| 
 | |
|         val_map = {input_node : output_node}
 | |
|         output = combined_graph.graph_copy(to_inline.graph, val_map)
 | |
|         combined_graph.output(output)
 | |
| 
 | |
|         combined_module = torch.fx.GraphModule(torch.nn.Module(), combined_graph)
 | |
| 
 | |
|         input = torch.rand(3, 4)
 | |
|         self.assertEqual(combined_module(input), input.relu().neg())
 | |
| 
 | |
|     def test_multi_insert_point(self):
 | |
|         graph = torch.fx.Graph()
 | |
|         x = torch.fx.Proxy(graph.placeholder('x'))
 | |
|         relu = torch.relu(x)
 | |
| 
 | |
|         with graph.inserting_before(relu.node):
 | |
|             y = torch.neg(x)
 | |
|             z = torch.tanh(y)
 | |
| 
 | |
|         graph.output((relu.node, z.node))
 | |
|         graph.lint()
 | |
| 
 | |
|         expected_ops = ['x', 'neg', 'tanh', 'relu']
 | |
|         for node, expected in zip(graph.nodes, expected_ops):
 | |
|             assert expected in node.name
 | |
| 
 | |
    def test_reassign_args_kwargs_uses(self):
        """Reassigning ``Node.args`` updates the ``users`` bookkeeping of inputs.

        After each reassignment, the ``users`` of ``x`` must reflect exactly
        the nodes that now consume it.
        """
        graph = torch.fx.Graph()
        x, y = Proxy(graph.placeholder('x')), Proxy(graph.placeholder('y'))
        z = x + y
        # Python evaluates this as (z + z) + z, so zed.node.args[0] is the inner add.
        zed = z + z + z
        graph.output(zed.node)
        graph.lint()

        # zed = z + z + z -> zed = z + z + x
        # Replace zed's second operand with x; x gains zed as a user.
        zed.node.args = (zed.node.args[0], x.node)
        self.assertEqual(x.node.users.keys(), [z.node, zed.node])

        # z = x + y -> z = y + y
        # z no longer reads x, so only zed should remain among x's users.
        z.node.args = (y.node, y.node)
        self.assertEqual(x.node.users.keys(), [zed.node])
 | |
| 
 | |
|     def test_trace_function(self):
 | |
|         def foo(x, y):
 | |
|             return torch.relu(x) + y
 | |
| 
 | |
|         x, y = torch.randn(3, 4), torch.randn(3, 4)
 | |
|         self.checkGraphModule(foo, (x, y))
 | |
| 
 | |
|     def test_trace_dict_int_keys(self):
 | |
|         class ModWithDictArg(torch.nn.Module):
 | |
|             def forward(self, d : Dict[int, torch.Tensor]):
 | |
|                 return d[42]
 | |
| 
 | |
|         class CallsModWithDict(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.m = ModWithDictArg()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.m({42: x})
 | |
| 
 | |
|         class MyTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
 | |
|                 return isinstance(m, ModWithDictArg)
 | |
| 
 | |
|         traced_graph = MyTracer().trace(CallsModWithDict())
 | |
| 
 | |
|     def test_trace_dict_proxy_keys(self):
 | |
|         class ModWithDictArg(torch.nn.Module):
 | |
|             def forward(self, d : Dict[torch.Tensor, torch.Tensor]):
 | |
|                 return d[42]
 | |
| 
 | |
|         class CallsModWithDict(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.m = ModWithDictArg()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.m({x: x})
 | |
| 
 | |
|         class MyTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
 | |
|                 return isinstance(m, ModWithDictArg)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, 'cannot contain a Node'):
 | |
|             traced_graph = MyTracer().trace(CallsModWithDict())
 | |
| 
 | |
|     def test_module_deepcopy_edit_nodes(self):
 | |
|         class Foo(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         traced1 = symbolic_trace(Foo())
 | |
|         copied = copy.deepcopy(traced1)
 | |
| 
 | |
|         for node in copied.graph.nodes:
 | |
|             if node.target == torch.relu:
 | |
|                 node.target = torch.neg
 | |
| 
 | |
|         copied.recompile()
 | |
|         traced1.recompile()
 | |
| 
 | |
|         x = torch.randn(15, 15)
 | |
|         torch.testing.assert_allclose(traced1(x), torch.relu(x))
 | |
|         torch.testing.assert_allclose(copied(x), torch.neg(x))
 | |
| 
 | |
|     def test_direct_param_use(self):
 | |
|         class TransposeTest(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.b = torch.nn.Parameter(torch.rand(4, 3))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.b
 | |
| 
 | |
|         class Foo(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.a = TransposeTest()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.a.b, self.a.b.t(), self.a.b.view(12)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(Foo())
 | |
|         assert(all('constant' not in node.target for node in traced.graph.nodes))
 | |
| 
 | |
|     def test_single_default_arg(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self, y=1):
 | |
|                 return y
 | |
| 
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, ())
 | |
|         self.checkGraphModule(m, (3,))
 | |
| 
 | |
|     def test_multiple_default_args(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self, y=1, z=2):
 | |
|                 return y + z
 | |
| 
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, ())
 | |
|         self.checkGraphModule(m, (3,))
 | |
|         self.checkGraphModule(m, (3, 4))
 | |
| 
 | |
|     def test_regular_and_default_args(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self, x, y=1):
 | |
|                 return x + y
 | |
| 
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, (2,))
 | |
|         self.checkGraphModule(m, (2, 3))
 | |
| 
 | |
|     def test_string_literal_return(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
| 
 | |
|             def forward(self):
 | |
|                 return "foo"
 | |
| 
 | |
|         m = M()
 | |
|         self.checkGraphModule(m, ())
 | |
| 
 | |
|     def test_namedtuple_return_qualname(self):
 | |
|         class NamedTupReturn(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return MyNamedTup(x, x)
 | |
| 
 | |
|         traced = symbolic_trace(NamedTupReturn())
 | |
|         input = torch.rand(3, 4)
 | |
|         self.assertEqual(traced(input), MyNamedTup(input, input))
 | |
| 
 | |
|     def test_update_args_kwargs_yells_at_you(self):
 | |
|         symtraced = symbolic_trace(SimpleTest())
 | |
|         node = next(iter(symtraced.graph.nodes))
 | |
|         with self.assertRaisesRegex(AttributeError, '__update_args_kwargs'):
 | |
|             node.__update_args_kwargs((), {})
 | |
| 
 | |
|     def test_torchbind_class_attribute_in_fx(self):
 | |
|         if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS:
 | |
|             self.skipTest("torch.classes._TorchScriptTesting._StackString is registered, skipping")
 | |
| 
 | |
|         class FooBar1234(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(FooBar1234, self).__init__()
 | |
|                 self.f = torch.classes._TorchScriptTesting._StackString(["3", "4"])
 | |
| 
 | |
|             def forward(self):
 | |
|                 return self.f.top()
 | |
| 
 | |
|         m = FooBar1234()
 | |
|         self.checkGraphModule(m, ())
 | |
| 
 | |
|     def test_torchbind_class_attribute_in_fx_tensor_arg(self):
 | |
|         if TEST_WITH_ROCM or IS_FBCODE or IS_WINDOWS or IS_MACOS:
 | |
|             self.skipTest("torch.classes._TorchScriptTesting._ReLUClass is registered, skipping")
 | |
| 
 | |
|         class FooBar2341(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(FooBar2341, self).__init__()
 | |
|                 self.f = torch.classes._TorchScriptTesting._ReLUClass()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.f.run(x)
 | |
| 
 | |
|         m = FooBar2341()
 | |
| 
 | |
|         traced = symbolic_trace(m)
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(traced(input), m(input))
 | |
| 
 | |
|         self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
 | |
| 
 | |
|     def test_script_method_trace(self):
 | |
|         class Scripted(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         class Holder(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.s = torch.jit.script(Scripted())
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.s(x)
 | |
| 
 | |
|         h = Holder()
 | |
|         traced = symbolic_trace(h)
 | |
|         input = torch.randn(3, 4)
 | |
|         self.assertEqual(traced(input), h(input))
 | |
| 
 | |
|         self.assertTrue(any(n.op == 'call_method' for n in traced.graph.nodes))
 | |
| 
 | |
|     def test_namedtuple_return_trace(self):
 | |
|         class NamedTupReturn(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return Pair(x, x)
 | |
| 
 | |
|         traced = symbolic_trace(NamedTupReturn())
 | |
|         input = torch.rand(3, 4)
 | |
|         self.assertEqual(traced(input), Pair(input, input))
 | |
| 
 | |
|     def test_return_type_exists(self):
 | |
|         class ReturnTypeModule(torch.nn.Module):
 | |
|             def other(self, x: List[str]) -> List[str]:
 | |
|                 return x
 | |
| 
 | |
|             def forward(self, x: List[str]) -> List[str]:
 | |
|                 return self.other(x)
 | |
| 
 | |
|         traced = symbolic_trace(ReturnTypeModule())
 | |
|         self.assertIn("-> typing_List[str]", traced._code)
 | |
|         scripted = torch.jit.script(traced)
 | |
|         self.assertIn("-> List[str]", scripted.code)
 | |
| 
 | |
|     def getitem_inner(self):
 | |
|         class GetItemBase(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.register_buffer('pe', torch.randn(8, 8))
 | |
| 
 | |
|         class GetItem1(GetItemBase):
 | |
|             def forward(self, x):
 | |
|                 return self.pe[:, :x.size(0)]
 | |
| 
 | |
|         class GetItem2(GetItemBase):
 | |
|             def forward(self, x):
 | |
|                 return self.pe[x.size(0)]
 | |
| 
 | |
|         class GetItem3(GetItemBase):
 | |
|             def forward(self, x):
 | |
|                 return self.pe[4]  # fx creates `self._tensor_constant0` here
 | |
| 
 | |
|         self.checkGraphModule(GetItem1(), [torch.zeros(4)])
 | |
|         self.checkGraphModule(GetItem2(), [torch.zeros(4)])
 | |
|         self.checkGraphModule(GetItem3(), [torch.zeros(4)])
 | |
| 
 | |
|     @unittest.skipUnless(os.environ.get("FX_PATCH_GETITEM") == "1",
 | |
|                          "Will be checked in test_getitem_subproc")
 | |
|     def test_getitem(self):
 | |
|         self.getitem_inner()
 | |
| 
 | |
|     def test_getitem_subproc(self):
 | |
|         # need to run this test in a subproc to work around:
 | |
|         #   https://github.com/pytorch/pytorch/issues/50710
 | |
|         proc = Process(target=run_getitem_target)
 | |
|         proc.start()
 | |
|         proc.join()
 | |
|         self.assertEqual(proc.exitcode, 0)
 | |
| 
 | |
| 
 | |
|     def test_user_friendly_call_provenance_with_function(self):
 | |
|         def fn(x):
 | |
|             return wrapper_fn(x)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(fn)
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
 | |
|                                     "being compiled since it was called"
 | |
|                                     " from 'fn.forward'"):
 | |
|             scripted = torch.jit.script(traced)
 | |
| 
 | |
|     def test_user_friendly_call_provenance_with_module(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return wrapper_fn(x)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(M())
 | |
| 
 | |
|         with self.assertRaisesRegex(RuntimeError, "'wrapper_fn' is "
 | |
|                                     "being compiled since it was called"
 | |
|                                     " from 'M.forward'"):
 | |
|             scripted = torch.jit.script(traced)
 | |
| 
 | |
|     def test_snake_case(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(M, self).__init__()
 | |
|                 self.activations = torch.nn.ModuleDict([
 | |
|                     ["snake_case", torch.nn.ReLU()],
 | |
|                     ["PascalCase", torch.nn.LeakyReLU()],
 | |
|                     ["ALL_CAPS", torch.nn.PReLU()]
 | |
|                 ])
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 a = self.activations["snake_case"](x)
 | |
|                 b = self.activations["PascalCase"](x)
 | |
|                 c = self.activations["ALL_CAPS"](x)
 | |
|                 return a, b, c
 | |
| 
 | |
|         traced = symbolic_trace(M())
 | |
| 
 | |
|         check = [
 | |
|             ("activations_snake_case", "activations.snake_case"),
 | |
|             ("activations_pascal_case", "activations.PascalCase"),
 | |
|             ("activations_all_caps", "activations.ALL_CAPS")
 | |
|         ]
 | |
| 
 | |
|         i = 0
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "placeholder" or node.op == "output":
 | |
|                 continue
 | |
|             name = check[i][0]
 | |
|             target = check[i][1]
 | |
|             self.assertEqual(name, node.name)
 | |
|             self.assertEqual(target, node.target)
 | |
|             i += 1
 | |
|         self.assertEqual(i, 3)
 | |
| 
 | |
|     def test_no_mutation(self):
 | |
|         from torch.fx.immutable_collections import immutable_list
 | |
|         x = immutable_list([3, 4])
 | |
|         with self.assertRaisesRegex(NotImplementedError, "new_args"):
 | |
|             x[0] = 4
 | |
| 
 | |
|     def test_partial_trace(self):
 | |
|         class Foo(torch.nn.Module):
 | |
|             def forward(self, x, y):
 | |
|                 if y:
 | |
|                     return 2 * x
 | |
|                 else:
 | |
|                     return x
 | |
|         mod = Foo()
 | |
|         mod_true = symbolic_trace(mod, concrete_args={'y': True})
 | |
|         mod_false = symbolic_trace(mod, concrete_args={'y': False})
 | |
|         self.assertEqual(mod_true(3, True), 6)
 | |
|         print(mod_true.code)
 | |
|         assert(any([i.target == torch._assert for i in mod_true.graph.nodes]))
 | |
|         with self.assertRaises(AssertionError):
 | |
|             mod_true(3, False)
 | |
|         self.assertEqual(mod_false(3, False), 3)
 | |
|         with self.assertRaises(AssertionError):
 | |
|             mod_false(3, True)
 | |
| 
 | |
|         def f_higher(a, f):
 | |
|             return f(a)
 | |
| 
 | |
|         nf = symbolic_trace(f_higher, concrete_args={'f': lambda x: x * 2})
 | |
|         self.assertEqual(nf(3, lambda x: x * 2), 6)
 | |
| 
 | |
|     def test_custom_traceback_raised_when_exception_source_is_graphmodule(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(M, self).__init__()
 | |
|                 self.W = torch.nn.Parameter(torch.randn(5))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return torch.dot(self.W, x)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(M())
 | |
| 
 | |
|         out = [n for n in traced.graph.nodes if n.op == "output"][-1]
 | |
|         with traced.graph.inserting_before(out):
 | |
|             relu_out = traced.graph.call_method(method_name='relu',
 | |
|                                                 args=(out.args[0],))
 | |
|         out.args = (relu_out,)
 | |
| 
 | |
|         traced.recompile()
 | |
| 
 | |
|         with self.capture_stderr() as captured:
 | |
|             with self.assertRaises(TypeError):
 | |
|                 traced(5)
 | |
| 
 | |
|         self.assertRegex(captured[0],
 | |
|                          r"Call using an FX-traced Module, line .* of the "
 | |
|                          r"traced Module's generated forward function:")
 | |
| 
 | |
|     def test_custom_traceback_not_raised_when_exception_source_is_submodule(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(3, 4)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(x)
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(M())
 | |
| 
 | |
|         # Do not change this to `capture_stderr` or another context
 | |
|         # manager without ensuring that the output is as expected
 | |
|         try:
 | |
|             traced(torch.rand(5, 5))
 | |
|         except RuntimeError:
 | |
|             captured = traceback.format_exc()
 | |
| 
 | |
|         self.assertNotRegex(captured,
 | |
|                             r"Call using an FX-traced Module, line .* of the "
 | |
|                             r"traced Module's generated forward function:")
 | |
| 
 | |
|     def test_graph_module_replicate_for_dp(self):
 | |
|         class Foo(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return torch.relu(x)
 | |
| 
 | |
|         gm = torch.fx.symbolic_trace(Foo())
 | |
| 
 | |
|         x = torch.randn(5, 3)
 | |
|         out = gm(x)
 | |
| 
 | |
|         replica = gm._replicate_for_data_parallel()
 | |
|         out_replica = replica(x)
 | |
| 
 | |
|         torch.testing.assert_allclose(out_replica, out)
 | |
| 
 | |
|     def test_ast_rewriter_rewrites_assert(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: torch.Tensor, y: int, z: int):
 | |
|                 assert y == z
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(M())
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         traced.graph.lint()
 | |
| 
 | |
|     def test_ast_rewriter_rewrites_assert_with_message(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: torch.Tensor, y: int, z: int):
 | |
|                 assert y == z, "msg"
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(M())
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         traced.graph.lint()
 | |
| 
 | |
|     def test_throw_out_variant(self):
 | |
|         def foo(x):
 | |
|             y = torch.rand_like(x)
 | |
|             torch.sigmoid(x, out=y)
 | |
|             return y
 | |
| 
 | |
|         class MyTracer(torch.fx.Tracer):
 | |
|             check_mutable_operations = True
 | |
| 
 | |
|         tracer = MyTracer()
 | |
|         with self.assertRaisesRegex(RuntimeError, 'mutable operation aten::sigmoid.out'):
 | |
|             traced_graph = tracer.trace(foo)
 | |
| 
 | |
|     def test_ast_rewriter_reassigns_submodules(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.bn = torch.nn.BatchNorm2d(100)
 | |
| 
 | |
|             def forward(self, x: torch.Tensor):
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(M())
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         traced.graph.lint()
 | |
| 
 | |
|     def test_ast_rewriter_wrap(self):
 | |
|         self.assertEqual(3 + 4 + 5, a_lifted_leaf((3, 4), 5))
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return (
 | |
|                 a_lifted_leaf((4, y), 3)
 | |
|                 + a_lifted_leaf((3, 4), 5)
 | |
|                 + a_lifted_leaf((y, y), y)
 | |
|             )
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(to_trace)
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         self.assertIn("a_lifted_leaf", traced.code)
 | |
|         self.assertEqual(27, traced(2))
 | |
|         self.assertIs(a_lifted_leaf, real_a_lifed_leaf)
 | |
| 
 | |
|     def test_ast_rewriter_wrap_fn_directly(self):
 | |
|         self.assertEqual(3 + 4 + 5, a_lifted_leaf2((3, 4), 5))
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return (
 | |
|                 a_lifted_leaf2((4, y), 3)
 | |
|                 + a_lifted_leaf2((3, 4), 5)
 | |
|                 + a_lifted_leaf2((y, y), y)
 | |
|             )
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(to_trace)
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         self.assertIn("a_lifted_leaf2", traced.code)
 | |
|         self.assertEqual(27, traced(2))
 | |
|         self.assertIs(a_lifted_leaf2, real_a_lifed_leaf2)
 | |
| 
 | |
|     def test_profiler_ranges_side_effect(self):
 | |
|         g = torch.fx.Graph()
 | |
|         handle = g.call_function(torch.ops.profiler._record_function_enter, ('test_range',))
 | |
|         g.call_function(torch.ops.profiler._record_function_exit, (handle,))
 | |
|         g.output(None)
 | |
| 
 | |
|         found_targets = {}
 | |
|         for node in g.nodes:
 | |
|             if node.op == 'call_function':
 | |
|                 found_targets.setdefault(node.target)
 | |
|         self.assertEqual(
 | |
|             found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
 | |
| 
 | |
|         g.eliminate_dead_code()
 | |
|         found_targets = {}
 | |
|         for node in g.nodes:
 | |
|             if node.op == 'call_function':
 | |
|                 found_targets.setdefault(node.target)
 | |
|         self.assertEqual(
 | |
|             found_targets.keys(), [torch.ops.profiler._record_function_enter, torch.ops.profiler._record_function_exit])
 | |
| 
 | |
|     def test_ast_rewriter_wrapped_via_decorator(self):
 | |
|         class F(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 return wrapped_via_decorator(x)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(F())
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         self.assertIn("wrapped_via_decorator", traced.code)
 | |
|         self.assertEqual(traced(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|     def test_ast_rewriter_wrapped_via_decorator_and_transformed(self):
 | |
|         self.assertEqual(wrapped_via_decorator(0), 1)
 | |
| 
 | |
|         def to_trace(y):
 | |
|             return wrapped_via_decorator(y)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(to_trace)
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         self.assertIn("wrapped_via_decorator", traced.code)
 | |
|         self.assertEqual(traced(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|         transformed = torch.fx.Transformer(traced).transform()
 | |
|         self.assertIn("wrapped_via_decorator", transformed.code)
 | |
|         self.assertEqual(transformed(0), 1)
 | |
|         self.assertIs(wrapped_via_decorator, real_wrapped_via_decorator)
 | |
|         self.assertFalse(hasattr(wrapped_via_decorator, "__fx_already_patched"))
 | |
| 
 | |
|     def test_ast_rewriter_wrap_with_submodule(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(M, self).__init__()
 | |
|                 self.batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
 | |
| 
 | |
|             def forward(self, x: torch.Tensor):
 | |
|                 return wrapped_with_submodule(x, self.batchnorm1d)
 | |
| 
 | |
|         ast_rewriter = RewritingTracer()
 | |
|         graph = ast_rewriter.trace(M())
 | |
|         traced = GraphModule(ast_rewriter.root, graph, "gm")
 | |
| 
 | |
|         self.assertIn("wrapped_with_submodule", traced.code)
 | |
| 
 | |
|         input = torch.rand(3, 2)
 | |
|         ref_batchnorm1d = torch.nn.BatchNorm1d(2, affine=False)
 | |
|         self.assertEqual(ref_batchnorm1d(input), traced(input))
 | |
| 
 | |
|     def test_submodule_manipulation_API(self):
 | |
|         class C(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(C, self).__init__()
 | |
|                 self.conv = torch.nn.Conv2d(16, 33, 3, stride=2)
 | |
|                 self.param = torch.nn.Parameter(torch.rand(2, 3))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.conv(torch.cat([self.param, x]))
 | |
| 
 | |
|         class B(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(B, self).__init__()
 | |
|                 self.linear = torch.nn.Linear(100, 200)
 | |
|                 self.register_buffer("buf", torch.randn(2, 3))
 | |
|                 self.net_c = C()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(torch.cat([self.buf, self.net_c(x)]))
 | |
| 
 | |
|         class A(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(A, self).__init__()
 | |
|                 self.net_b = B()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(2, 3))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.net_b(x) + self.param
 | |
| 
 | |
|         a = symbolic_trace(A())
 | |
| 
 | |
|         a.add_submodule("net_b.net_c.dropout", torch.nn.Dropout(p=0.2))
 | |
| 
 | |
|         conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"][-1]
 | |
|         with a.graph.inserting_before(conv):
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 dropout = a.graph.call_module(module_name="net_b.net_c.dropout",
 | |
|                                               args=conv.args)
 | |
|                 self.assertEqual(len(w), 0)
 | |
| 
 | |
|         conv.replace_all_uses_with(dropout)
 | |
|         a.graph.erase_node(conv)
 | |
|         a.recompile()
 | |
| 
 | |
|         def module_exists(gm: GraphModule, path: str) -> bool:
 | |
|             return any(path == name for name, _ in gm.named_modules())
 | |
| 
 | |
|         def parameter_exists(gm: GraphModule, path: str) -> bool:
 | |
|             return (any(path == name for name, _ in gm.named_parameters())
 | |
|                     and any(path == name for name in gm.state_dict().keys()))
 | |
| 
 | |
|         def buffer_exists(gm: GraphModule, path: str) -> bool:
 | |
|             return (any(path == name for name, _ in gm.named_buffers())
 | |
|                     and any(path == name for name in gm.state_dict().keys()))
 | |
| 
 | |
|         # Test that we added the "dropout" submodule
 | |
|         self.assertTrue(module_exists(a, "net_b.net_c.dropout"))
 | |
| 
 | |
|         # Test `get_submodule` with an added submodule
 | |
|         self.assertIsNotNone(a.get_submodule("net_b.net_c.dropout"))
 | |
| 
 | |
|         # Test that the "conv" submodule is still there
 | |
|         self.assertTrue(module_exists(a, "net_b.net_c.conv"))
 | |
| 
 | |
|         # Test `get_submodule` with an original module
 | |
|         self.assertIsNotNone(a.get_submodule("net_b.net_c.conv"))
 | |
| 
 | |
|         # Test that the "conv" node is NOT still there
 | |
|         conv = [n for n in a.graph.nodes if n.target == "net_b.net_c.conv"]
 | |
|         self.assertEqual(conv, [])
 | |
| 
 | |
|         a.delete_submodule("net_b.net_c.conv")
 | |
| 
 | |
|         # Test that the "conv" submodule is now gone
 | |
|         self.assertFalse(module_exists(a, "net_b.net_c.conv"))
 | |
| 
 | |
|         # Test `get_submodule` with a deleted submodule
 | |
|         with self.assertRaisesRegex(AttributeError, "has no attribute "
 | |
|                                     "`conv`"):
 | |
|             self.assertIsNone(a.get_submodule("net_b.net_c.conv"))
 | |
| 
 | |
|         # Test `get_attr` warnings
 | |
|         cat = [n for n in a.graph.nodes if n.target == torch.cat][-1]
 | |
| 
 | |
|         with a.graph.inserting_before(cat):
 | |
| 
 | |
|             with warnings.catch_warnings(record=True) as w:
 | |
|                 param = a.graph.get_attr(qualified_name="net_b.net_c.param")
 | |
|                 self.assertEqual(len(w), 0)
 | |
| 
 | |
|             with self.assertWarnsRegex(UserWarning, "Attempted to "
 | |
|                                        "insert a get_attr Node with no "
 | |
|                                        "underlying reference in the "
 | |
|                                        "owning GraphModule"):
 | |
|                 bad_param = a.graph.get_attr(qualified_name="net_b.param")
 | |
|                 a.graph.erase_node(bad_param)
 | |
| 
 | |
|         cat.args = (*cat.args, param)
 | |
| 
 | |
|         a.recompile()
 | |
| 
 | |
|         a.graph.lint()
 | |
| 
 | |
|         # Test `get_parameter`
 | |
|         a.get_parameter("net_b.net_c.param")
 | |
|         with self.assertRaisesRegex(AttributeError, "is not an "
 | |
|                                     "nn.Parameter"):
 | |
|             a.get_parameter("net_b.buf")
 | |
|         with self.assertRaisesRegex(AttributeError, "has no attribute "
 | |
|                                     "`param`"):
 | |
|             a.get_parameter("net_b.param")
 | |
| 
 | |
|         # Test `get_buffer`
 | |
|         a.get_buffer("net_b.buf")
 | |
|         with self.assertRaisesRegex(AttributeError, "is not a "
 | |
|                                     "buffer"):
 | |
|             a.get_buffer("net_b.net_c.param")
 | |
|         with self.assertRaisesRegex(AttributeError, "has no attribute "
 | |
|                                     "`buf`"):
 | |
|             a.get_buffer("net_b.net_c.buf")
 | |
| 
 | |
|         # Test non-nested attributes
 | |
|         a.get_submodule("")
 | |
|         a.get_parameter("param")
 | |
| 
 | |
|         # Insert some unused submodules
 | |
|         a.add_submodule("net_b.embedding", torch.nn.Embedding(10, 3))
 | |
|         a.add_submodule("net_b.net_c.embedding", torch.nn.Embedding(10, 3))
 | |
|         a.add_submodule("net_b.net_c.rnn", torch.nn.RNN(10, 20, 2))
 | |
|         a.add_submodule("batch_norm_2d", torch.nn.BatchNorm2d(100))
 | |
| 
 | |
|         # Garbage collection
 | |
|         a.delete_all_unused_submodules()
 | |
| 
 | |
|         # Test that all the unused submodules are gone
 | |
|         self.assertFalse(module_exists(a, "net_b.embedding"))
 | |
|         self.assertFalse(module_exists(a, "net_b.net_c.embedding"))
 | |
|         self.assertFalse(module_exists(a, "net_b.net_c.rnn"))
 | |
|         self.assertFalse(module_exists(a, "batch_norm_2d"))
 | |
| 
 | |
|         # Test that we didn't delete any unused Parameters or buffers
 | |
|         self.assertTrue(parameter_exists(a, "net_b.net_c.param"))
 | |
|         self.assertTrue(buffer_exists(a, "net_b.buf"))
 | |
| 
 | |
|         a.graph.lint()
 | |
| 
 | |
|     def test_delete_unused_submodules_leaf(self):
 | |
|         class SubModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(10, 10)
 | |
|                 self.relu = torch.nn.ReLU()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = self.linear(x)
 | |
|                 x = self.relu(x)
 | |
|                 return x
 | |
| 
 | |
|         class Model(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.submod = SubModule()
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = self.submod(x)
 | |
|                 return x
 | |
| 
 | |
|         model = Model()
 | |
| 
 | |
|         class MyCustomTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:
 | |
|                 return module_qualified_name == "submod"
 | |
| 
 | |
|         inputs = torch.randn(1, 10)
 | |
|         traced_graph = MyCustomTracer().trace(model)
 | |
|         gm2 = torch.fx.GraphModule(model, traced_graph)
 | |
|         gm2.delete_all_unused_submodules()
 | |
|         torch.testing.assert_allclose(gm2(inputs), model(inputs))
 | |
| 
 | |
|     def test_tracing_graphmodules_as_leaf_submodules(self):
 | |
|         class A(torch.nn.Module):
 | |
|             def forward(self, t):
 | |
|                 return t + t
 | |
| 
 | |
|         class B(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super(type(self), self).__init__()
 | |
|                 self.calling = False
 | |
|                 self.called = False
 | |
| 
 | |
|             def forward(self, t):
 | |
|                 if self.calling:
 | |
|                     return t - t
 | |
|                 else:
 | |
|                     return t + t
 | |
| 
 | |
|             def __call__(self, *args):
 | |
|                 self.called = True
 | |
|                 self.calling = True
 | |
|                 return super(type(self), self).__call__(*args)
 | |
|                 self.calling = False
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self, a, b):
 | |
|                 super().__init__()
 | |
|                 self.a = a
 | |
|                 self.b = b
 | |
| 
 | |
|             def forward(self, t):
 | |
|                 x = self.a(t)
 | |
|                 y = self.b(t)
 | |
|                 return x + y
 | |
| 
 | |
|         class LeafTracer(Tracer):
 | |
|             def is_leaf_module(self, module, name):
 | |
|                 return True
 | |
| 
 | |
|         class LeafTracerNotB(Tracer):
 | |
|             def is_leaf_module(self, module, name):
 | |
|                 return False if "b" in name else True
 | |
| 
 | |
|         # Recompile calls added "for fun", since they
 | |
|         # chain __call__ wrappers.
 | |
| 
 | |
|         #
 | |
|         # Test: B as a regular, non-leaf module
 | |
|         #
 | |
|         a = symbolic_trace(A())
 | |
|         a.recompile()
 | |
|         m = M(a, B())
 | |
|         graph = LeafTracerNotB().trace(m)
 | |
|         gm = GraphModule(m, graph)
 | |
|         gm.recompile()
 | |
| 
 | |
|         # Test graphmodule/submodule a is not inlined.
 | |
|         self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
 | |
|         match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|         # Test submodule b is not treated as leaf.
 | |
|         self.assertFalse(hasattr(gm, "b"))
 | |
| 
 | |
|         # Test assert custom __call__ on submodule b was honored.
 | |
|         match = [
 | |
|             n
 | |
|             for n in gm.graph.nodes
 | |
|             if n.op == "call_function" and n.target == operator.sub
 | |
|         ]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|         #
 | |
|         # Test: B as a regular, leaf module
 | |
|         # symbolic_trace should only patch torch.nn.Module.__call__,
 | |
|         # which means B.__call__ should still execute
 | |
|         #
 | |
|         a = symbolic_trace(A())
 | |
|         a.recompile()
 | |
|         b = B()
 | |
|         m = M(a, b)
 | |
|         graph = LeafTracer().trace(m)
 | |
|         gm = GraphModule(m, graph)
 | |
|         gm.recompile()
 | |
| 
 | |
|         # Test graphmodule/submodule a is not inlined.
 | |
|         self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
 | |
|         match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|         # Test submodule b is leaf:
 | |
|         self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
 | |
|         match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|         # Test b.__call__ was run
 | |
|         self.assertTrue(b.called)
 | |
|         self.assertTrue(gm.get_submodule("b").called)
 | |
| 
 | |
|         #
 | |
|         # Test: B as GraphModule leaf
 | |
|         # __call__ not honored since symbolic_trace directly invokes forward()
 | |
|         #
 | |
|         a = symbolic_trace(A())
 | |
|         a.recompile()
 | |
|         b = symbolic_trace(B())
 | |
|         b.recompile()
 | |
|         m = M(a, b)
 | |
|         graph = LeafTracer().trace(m)
 | |
|         gm = GraphModule(m, graph)
 | |
|         gm.recompile()
 | |
| 
 | |
|         self.assertTrue(isinstance(gm.get_submodule("a"), GraphModule))
 | |
|         match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "a"]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|         self.assertTrue(isinstance(gm.get_submodule("b"), torch.nn.Module))
 | |
|         match = [n for n in gm.graph.nodes if n.op == "call_module" and n.target == "b"]
 | |
|         self.assertTrue(len(match) == 1)
 | |
| 
 | |
|     def _test_graph_module_init_buffer_param_copied(self, use_dict_init: bool):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self):
 | |
|                 super().__init__()
 | |
|                 self.register_buffer("my_buff", torch.rand(3, 4))
 | |
|                 self.register_parameter(
 | |
|                     "my_param", torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 )
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return x + self.my_buff + self.my_param
 | |
| 
 | |
|         mod = MyModule()
 | |
|         mod_traced = symbolic_trace(mod)
 | |
| 
 | |
|         # Create new GraphModule based on original, either w/ dict or root module.
 | |
|         orig_buff = mod_traced.get_buffer("my_buff")
 | |
|         orig_param = mod_traced.get_parameter("my_param")
 | |
|         mod_traced_new = GraphModule(
 | |
|             {"my_buff": orig_buff, "my_param": orig_param} if use_dict_init else mod,
 | |
|             mod_traced.graph,
 | |
|         )
 | |
| 
 | |
|         # Check that both my_buff and my_param are found and the same.
 | |
|         try:
 | |
|             new_buff = mod_traced_new.get_buffer("my_buff")
 | |
|         except Exception:
 | |
|             self.fail("Did not find my_buff")
 | |
|         self.assertEqual(orig_buff, new_buff)
 | |
| 
 | |
|         try:
 | |
|             new_param = mod_traced_new.get_parameter("my_param")
 | |
|         except Exception:
 | |
|             self.fail("Did not find my_param")
 | |
|         self.assertEqual(orig_param, new_param)
 | |
| 
 | |
|         x = torch.rand(3, 4)
 | |
|         orig_out = mod_traced(x)
 | |
|         submodules_out = mod_traced_new(x)
 | |
| 
 | |
|         self.assertEqual(orig_out, submodules_out)
 | |
| 
 | |
|     def test_graph_module_init_buffer_param_copied_dict_init(self):
 | |
|         self._test_graph_module_init_buffer_param_copied(use_dict_init=True)
 | |
| 
 | |
|     def test_graph_module_init_buffer_param_copied_mod_init(self):
 | |
|         self._test_graph_module_init_buffer_param_copied(use_dict_init=False)
 | |
| 
 | |
|     def test_annotations_with_no_forward_references(self):
 | |
|         class A:
 | |
|             def __call__(self, x: torch.Tensor):
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
 | |
|                 return a(x)
 | |
| 
 | |
|         self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
 | |
| 
 | |
|     def test_annotations_with_forward_references(self):
 | |
|         class A:
 | |
|             def __call__(self, x: torch.Tensor):
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: 'torch.Tensor', a: 'A') -> 'torch.Tensor':
 | |
|                 return a(x)
 | |
| 
 | |
|         self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
 | |
| 
 | |
|     def test_annotations_with_non_torch_reference_and_no_internal_forward_references(self):
 | |
|         class A:
 | |
|             def __call__(self, x: torch.Tensor):
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: List[torch.Tensor], a: A) -> torch.Tensor:
 | |
|                 return a(x[0])
 | |
| 
 | |
|         self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
 | |
| 
 | |
|     def test_annotations_with_non_torch_reference_and_internal_forward_references(self):
 | |
|         class A:
 | |
|             def __call__(self, x: torch.Tensor):
 | |
|                 return torch.add(x, x)
 | |
| 
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, x: List['torch.Tensor'], a: A) -> 'torch.Tensor':
 | |
|                 return a(x)[0]
 | |
| 
 | |
|         self.checkGraphModule(M(), (torch.rand(2, 3), A()), kwargs=None)
 | |
| 
 | |
|     @unittest.skipIf(sys.version_info < (3, 7), "`__future__` feature "
 | |
|                      "`annotations` is not defined in Python <3.7")
 | |
|     def test_annotation_with_future(self):
 | |
|         try:
 | |
|             import fx.test_future    # noqa: F401
 | |
|         finally:
 | |
|             del sys.modules["__future__"]
 | |
| 
 | |
|     def test_annotations_empty_tuple(self):
 | |
|         class Foo(torch.nn.Module):
 | |
|             def forward(self, x: Tuple[()], y: Tuple[str, Tuple[()]]):
 | |
|                 return "foo"
 | |
| 
 | |
|         traced = torch.fx.symbolic_trace(Foo())
 | |
| 
 | |
|         x = ()
 | |
|         y = ("bar", ())
 | |
| 
 | |
|         traced(x, y)
 | |
| 
 | |
|         FileCheck().check("_Tuple[()]")   \
 | |
|                    .check("typing_Tuple[str,typing_Tuple[()]]") \
 | |
|                    .run(traced.code)
 | |
| 
 | |
|         scripted = torch.jit.script(traced)
 | |
| 
 | |
|         scripted(x, y)
 | |
| 
 | |
|         FileCheck().check("Tuple[()]")   \
 | |
|             .check("Tuple[str, Tuple[()]]")    \
 | |
|             .run(scripted.code)
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_cpatcher(self):
 | |
| 
 | |
|         cnt = 0
 | |
| 
 | |
|         def patched_impl(to_patch, args, kwargs):
 | |
|             nonlocal cnt
 | |
|             cnt += 1
 | |
|             return to_patch(*args, **kwargs)
 | |
| 
 | |
|         c_patch_enabled = True
 | |
| 
 | |
|         def patched_in(to_patch, args, kwargs):
 | |
|             nonlocal c_patch_enabled
 | |
|             try:
 | |
|                 c_patch_enabled = False
 | |
|                 r = patched_impl(to_patch, args, kwargs)
 | |
|             finally:
 | |
|                 c_patch_enabled = True
 | |
|             return r
 | |
| 
 | |
| 
 | |
|         def trace_func(frame, action, arg):
 | |
|             if action == 'c_call':
 | |
|                 if c_patch_enabled:
 | |
|                     torch._C._fx.patch_function(arg, patched_in)
 | |
| 
 | |
| 
 | |
|         import torch
 | |
|         rn = torchvision_models.resnet18()
 | |
| 
 | |
|         try:
 | |
|             sys.setprofile(trace_func)
 | |
|             rn(torch.rand(1, 3, 224, 224))
 | |
|             print("testing print patch")
 | |
|         finally:
 | |
|             sys.setprofile(None)
 | |
|         assert(cnt != 0)
 | |
| 
 | |
|     def test_randn(self):
 | |
|         def f():
 | |
|             return torch.randn(3, 3)
 | |
| 
 | |
|         fx_f = symbolic_trace(f, enable_cpatching=True)
 | |
|         assert(any(i.target == torch.randn for i in fx_f.graph.nodes))
 | |
| 
 | |
|         fx_f = symbolic_trace(f, enable_cpatching=False)
 | |
|         assert(all(i.target != torch.randn for i in fx_f.graph.nodes))
 | |
| 
 | |
|         fx_f = symbolic_trace(f, enable_cpatching=True)
 | |
|         assert(any(i.target == torch.randn for i in fx_f.graph.nodes))
 | |
| 
 | |
| 
 | |
|     def test_pytree(self):
 | |
|         def f_sum(x):
 | |
|             return sum(x)
 | |
| 
 | |
|         def f_sum_dict(x):
 | |
|             out = 0
 | |
|             for k, v in x.items():
 | |
|                 out += v
 | |
|             return out
 | |
| 
 | |
|         def f_dict_list_map(x):
 | |
|             new_dict = {}
 | |
|             for k, v in x.items():
 | |
|                 new_dict[k] = [i + 1 for i in v]
 | |
|             return new_dict
 | |
| 
 | |
|         def f_dict_add(x):
 | |
|             return x['a'] + sum(x['z'])
 | |
| 
 | |
|         def f_namedtuple_add(x):
 | |
|             return x.x + x.y
 | |
| 
 | |
|         pytree._register_pytree_node(
 | |
|             Foo,
 | |
|             lambda x: ([x.a, x.b], None),
 | |
|             lambda x, _: Foo(x[0], x[1]),
 | |
|         )
 | |
|         fx_pytree.register_pytree_flatten_spec(Foo, lambda x, _: [x.a, x.b])
 | |
| 
 | |
|         def f_custom(x):
 | |
|             return x.a + x.b
 | |
| 
 | |
|         def f_custom_dict(x):
 | |
|             return f_sum_dict(x.a) + x.b
 | |
| 
 | |
|         def f_return_custom(x):
 | |
|             return Foo(x.b, x.a)
 | |
| 
 | |
|         tests = [
 | |
|             (f_sum, [PH, PH, PH]),
 | |
|             (f_sum, []),
 | |
|             (f_sum_dict, {'a': PH, 'b': PH, 'c': PH}),
 | |
|             (f_dict_list_map, {'a': (PH, PH), 'b': [PH], 'c': []}),
 | |
|             (f_dict_list_map, {5: (PH, PH, PH)}),
 | |
|             (f_dict_add, {'a': PH, 'z': (PH, PH, PH)}),
 | |
|             (f_dict_add, {'a': PH, 'z': []}),
 | |
|             (f_custom, Foo(PH, PH)),
 | |
|             (f_custom, Foo(PH, 3)),
 | |
|             (f_custom_dict, Foo({'a': PH, 'b': PH}, PH)),
 | |
|             # (f_return_custom, Foo(PH, PH)), # Don't currently support output pytrees
 | |
|             (f_namedtuple_add, Point(PH, PH)),
 | |
|         ]
 | |
| 
 | |
|         def verify_pytree(f, inp):
 | |
|             val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
 | |
|             num_flat_args = len([i == PH for i in pytree.tree_flatten(inp)[0]])
 | |
|             orig_out = f(val)
 | |
|             nf = symbolic_trace(f, concrete_args={'x': inp})
 | |
|             self.assertEqual(nf(val), orig_out)
 | |
|             assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
 | |
|             assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
 | |
| 
 | |
|             nf = symbolic_trace(nf)
 | |
|             self.assertEqual(nf(val), orig_out)
 | |
|             assert "tree_flatten_spec" not in nf.code
 | |
|             assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == 1)
 | |
| 
 | |
|             nf = symbolic_trace(nf, concrete_args={'x': inp})
 | |
|             self.assertEqual(nf(val), orig_out)
 | |
|             assert num_flat_args == 0 or "tree_flatten_spec" in nf.code
 | |
|             assert(sum([i.op == 'placeholder' for i in nf.graph.nodes]) == num_flat_args)
 | |
| 
 | |
|             pickled = pickle.dumps(nf)
 | |
|             nf = pickle.loads(pickled)
 | |
|             self.assertEqual(nf(val), orig_out)
 | |
| 
 | |
|         for f, inp in tests:
 | |
|             verify_pytree(f, inp)
 | |
| 
 | |
|     def test_pytree_concrete(self):
 | |
|         def f(b, a):
 | |
|             if b:
 | |
|                 return a['a']
 | |
|             else:
 | |
|                 return a['z']
 | |
| 
 | |
|         inp = {'a': {'a': PH, 'z': PH}, 'b': True}
 | |
|         nf = symbolic_trace(f, concrete_args=inp)
 | |
|         val = pytree.tree_map(lambda x: torch.randn(3) if x == PH else x, inp)
 | |
|         self.assertEqual(nf(**val), f(**val))
 | |
| 
 | |
|         nf = symbolic_trace(nf)
 | |
|         self.assertEqual(nf(**val), f(**val))
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
def run_getitem_target():
    """Run ``TestFX().getitem_inner()`` with ``Tensor.__getitem__`` added to FX's patch list.

    ``(torch.Tensor, "__getitem__")`` is appended to
    ``_wrapped_methods_to_patch`` for the duration of the call. The last entry is
    popped afterwards, even on failure; this assumes ``getitem_inner`` leaves
    the list as it found it.
    """
    from torch.fx._symbolic_trace import _wrapped_methods_to_patch as methods_to_patch

    methods_to_patch.append((torch.Tensor, "__getitem__"))
    try:
        TestFX().getitem_inner()
    finally:
        methods_to_patch.pop()
 | |
| 
 | |
| 
 | |
class TestOperatorSignatures(JitTestCase):
    """Check ``get_signature_for_torch_op`` against every OpInfo in ``op_db``.

    For each op, some schema returned for ``op.op`` must bind every sample
    input, and the op must run with the bound arguments. Ops listed in
    ``known_no_schema``, and ``nn.functional`` ops, are allowed to fail.
    """

    def setUp(self):
        # Run the base-class setup so it is not silently skipped.
        super().setUp()
        # Checking for mutable operations while tracing is feature flagged.
        # Enable it in testing but not by default.
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        # Restore the global flag, then let the base class tear down.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
        super().tearDown()

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_get_torch_func_signature_exhaustive(self, device, dtype, op):
        # One entry per line to minimize merge conflicts.
        known_no_schema = {'block_diag',
                           'broadcast_tensors',
                           'cdist',
                           'contiguous',
                           'dstack',
                           'einsum',
                           'expand',
                           'expand_as',
                           'fill_',
                           'hstack',
                           'igamma',
                           'igammac',
                           'linalg.multi_dot',
                           'lu',
                           'T',   # Implemented with a lambda
                           'H',   # Implemented with a lambda
                           'mT',  # Implemented with a lambda
                           'mH',  # Implemented with a lambda
                           'norm',
                           'polygamma',
                           'special.polygamma',
                           'repeat',
                           'reshape_as',
                           'resize_',
                           'resize_as_',
                           'special.zeta',
                           'stack',
                           'to_sparse',
                           'view',
                           'view_as',
                           'nn.functional.hardshrink',
                           'vstack',
                           'where',
                           'zero_',
                           'bfloat16',
                           'bool',
                           'byte',
                           'char',
                           'double',
                           'float',
                           'half',
                           'int',
                           'long',
                           'short',
                           'empty_like',
                           'ones_like',
                           'randn_like',
                           'zeros_like',
                           'full_like',
                           '__getitem__',
                           '__radd__',
                           '__rsub__',
                           '__rmul__',
                           '__rdiv__',
                           '__rmod__',
                           '__rpow__',
                           '__rand__',
                           '__ror__',
                           '__rxor__',
                           '__rmatmul__'}

        try:
            sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
            schemas = get_signature_for_torch_op(op.op)
            if not schemas:
                raise RuntimeError('No Schemas Returned')
            for sample_input in sample_inputs_itr:
                # Iterate through overloads until we hit a match. If we exit this
                # loop via `else`, we haven't found a match
                for schema in schemas:
                    try:
                        bound_args = schema.bind(sample_input.input, *sample_input.args, **sample_input.kwargs)
                        bound_args.apply_defaults()
                        op(*bound_args.args, **bound_args.kwargs)
                        break
                    except TypeError:
                        # This overload does not accept the sample; try the next one.
                        pass
                else:
                    raise RuntimeError(f'Did not match any schemas for op {op.name}!')

        except Exception as e:
            # Failures are tolerated only for ops known to lack a schema (and
            # nn.functional ops). Otherwise, surface the original error as the cause.
            if op.name not in known_no_schema and "nn.functional" not in op.name:
                raise AssertionError(
                    f'Schema lookup or matching failed for op {op.name}') from e
 | |
| 
 | |
| 
 | |
| class TestFXAPIBackwardCompatibility(JitTestCase):
 | |
|     def setUp(self):
 | |
|         self.maxDiff = None
 | |
| 
 | |
|         # Checking for mutable operations whil tracing is feature flagged
 | |
|         # Enable it in testing but not by default
 | |
|         self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = True
 | |
| 
 | |
|     def tearDown(self):
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
 | |
| 
 | |
| 
 | |
|     def _fn_to_stable_annotation_str(self, obj):
 | |
|         """
 | |
|         Unfortunately we have to serialize function signatures manually since
 | |
|         serialization for `inspect.Signature` objects is not stable across
 | |
|         python versions
 | |
|         """
 | |
|         fn_name = torch.typename(obj)
 | |
| 
 | |
|         signature = inspect.signature(obj)
 | |
| 
 | |
|         sig_str = f'{fn_name}{signature}'
 | |
| 
 | |
|         arg_strs = []
 | |
|         for k, v in signature.parameters.items():
 | |
|             maybe_type_annotation = f': {self._annotation_type_to_stable_str(v.annotation, sig_str)}'\
 | |
|                 if v.annotation is not inspect.Signature.empty else ''
 | |
| 
 | |
|             def default_val_str(val):
 | |
|                 if isinstance(val, (tuple, list)):
 | |
|                     str_pieces = ['(' if isinstance(val, tuple) else '[']
 | |
|                     str_pieces.append(', '.join(default_val_str(v) for v in val))
 | |
|                     if isinstance(val, tuple) and len(str_pieces) == 2:
 | |
|                         str_pieces.append(',')
 | |
|                     str_pieces.append(')' if isinstance(val, tuple) else ']')
 | |
|                     return ''.join(str_pieces)
 | |
| 
 | |
|                 # Need to fix up some default value strings.
 | |
|                 # First case: modules. Default module `repr` contains the FS path of the module.
 | |
|                 # Don't leak that
 | |
|                 if isinstance(val, types.ModuleType):
 | |
|                     return f'<module {val.__name__}>'
 | |
| 
 | |
|                 # Second case: callables. Callables (such as lambdas) encode their address in
 | |
|                 # their string repr. Don't do that
 | |
|                 if callable(val):
 | |
|                     return f'<function {val.__name__}>'
 | |
| 
 | |
|                 return str(val)
 | |
| 
 | |
|             if v.default is not inspect.Signature.empty:
 | |
|                 default_val_str = default_val_str(v.default) if not isinstance(v.default, str) else f"'{v.default}'"
 | |
|                 maybe_default = f' = {default_val_str}'
 | |
|             else:
 | |
|                 maybe_default = ''
 | |
|             maybe_stars = ''
 | |
|             if v.kind == inspect.Parameter.VAR_POSITIONAL:
 | |
|                 maybe_stars = '*'
 | |
|             elif v.kind == inspect.Parameter.VAR_KEYWORD:
 | |
|                 maybe_stars = '**'
 | |
|             arg_strs.append(f'{maybe_stars}{k}{maybe_type_annotation}{maybe_default}')
 | |
| 
 | |
|         return_annot = f' -> {self._annotation_type_to_stable_str(signature.return_annotation, sig_str)}'\
 | |
|             if signature.return_annotation is not inspect.Signature.empty else ''
 | |
| 
 | |
|         return f'{fn_name}({", ".join(arg_strs)}){return_annot}'
 | |
| 
 | |
|     def _annotation_type_to_stable_str(self, t, sig_str):
 | |
|         if t is inspect.Signature.empty:
 | |
|             return ''
 | |
| 
 | |
|         # Forward ref
 | |
|         if isinstance(t, str):
 | |
|             return f"'{t}'"
 | |
|         if hasattr(typing, 'ForwardRef') and isinstance(t, typing.ForwardRef):
 | |
|             return t.__forward_arg__
 | |
|         if hasattr(typing, '_ForwardRef') and isinstance(t, typing._ForwardRef):
 | |
|             return t.__forward_arg__
 | |
| 
 | |
|         trivial_mappings = {
 | |
|             str : 'str',
 | |
|             int : 'int',
 | |
|             float: 'float',
 | |
|             bool: 'bool',
 | |
|             torch.dtype: 'torch.dtype',
 | |
|             torch.Tensor: 'torch.Tensor',
 | |
|             torch.device: 'torch.device',
 | |
|             torch.memory_format: 'torch.memory_format',
 | |
|             slice: 'slice',
 | |
|             torch.nn.Module: 'torch.nn.modules.module.Module',
 | |
|             torch.fx.Graph : 'torch.fx.graph.Graph',
 | |
|             torch.fx.Node : 'torch.fx.node.Node',
 | |
|             torch.fx.Proxy : 'torch.fx.proxy.Proxy',
 | |
|             torch.fx.node.Target : 'torch.fx.node.Target',
 | |
|             torch.fx.node.Argument : 'torch.fx.node.Argument',
 | |
|             torch.fx.graph.PythonCode : 'torch.fx.graph.PythonCode',
 | |
|             torch.fx.graph_module.GraphModule: 'torch.fx.graph_module.GraphModule',
 | |
|             torch.fx.subgraph_rewriter.Match: 'torch.fx.subgraph_rewriter.Match',
 | |
|             Ellipsis : '...',
 | |
|             typing.Any: 'Any',
 | |
|             type(None): 'NoneType',
 | |
|             None: 'None',
 | |
|             typing.Iterator: 'Iterator',
 | |
|         }
 | |
| 
 | |
|         mapping = trivial_mappings.get(t, None)
 | |
|         if mapping:
 | |
|             return mapping
 | |
| 
 | |
|         # Handle types with contained types
 | |
|         contained = getattr(t, '__args__', None) or []
 | |
| 
 | |
|         # Callables contain a bare List for arguments
 | |
|         contained = t if isinstance(t, list) else contained
 | |
| 
 | |
|         # Python 3.8 puts type vars into __args__ for unbound types such as Dict
 | |
|         if all(isinstance(ct, typing.TypeVar) for ct in contained):
 | |
|             contained = []
 | |
| 
 | |
|         contained_type_annots = [self._annotation_type_to_stable_str(ct, sig_str) for ct in contained]
 | |
|         contained_type_str = f'[{", ".join(contained_type_annots)}]' if len(contained_type_annots) > 0 else ''
 | |
| 
 | |
| 
 | |
|         origin = getattr(t, '__origin__', None)
 | |
|         if origin is None:
 | |
|             # Unbound types don't have `__origin__` in some Python versions, so fix that up here.
 | |
|             origin = t if t in {typing.Tuple, typing.Union, typing.Dict, typing.List, typing.Type, typing.Callable} else origin
 | |
| 
 | |
|         if origin in {tuple, typing.Tuple}:
 | |
|             return f'Tuple{contained_type_str}'
 | |
|         if origin in {typing.Union}:
 | |
|             # Annoying hack to detect Optional
 | |
|             if len(contained) == 2 and (contained[0] is type(None)) ^ (contained[1] is type(None)):
 | |
|                 not_none_param = contained[0] if contained[0] is not type(None) else contained[1]
 | |
|                 return f'Optional[{self._annotation_type_to_stable_str(not_none_param, sig_str)}]'
 | |
|             return f'Union{contained_type_str}'
 | |
|         if origin in {dict, typing.Dict}:
 | |
|             return f'Dict{contained_type_str}'
 | |
|         if origin in {list, typing.List}:
 | |
|             return f'List{contained_type_str}'
 | |
|         if origin in {type, typing.Type}:
 | |
|             return f'Type{contained_type_str}'
 | |
|         if isinstance(t, typing.Callable):
 | |
|             if len(contained) > 0 and contained[0] is not Ellipsis:
 | |
|                 return f'Callable[[{", ".join(contained_type_annots[:-1])}], {contained_type_annots[-1]}]'
 | |
|             else:
 | |
|                 return f'Callable{contained_type_str}'
 | |
| 
 | |
|         raise RuntimeError(f'Unrecognized type {t} used in BC-compatible type signature {sig_str}.'
 | |
|                            f'Please add support for this type and confirm with the '
 | |
|                            f'FX team that your signature change is valid.')
 | |
| 
 | |
| 
 | |
|     def test_function_back_compat(self):
 | |
|         """
 | |
|         Test backward compatibility for function signatures with
 | |
|         @compatibility(is_backward_compatible=True). Currently this checks for
 | |
|         exact signature matches, which may lead to false positives. If this
 | |
|         becomes too annoying, we can refine this check to actually parse out
 | |
|         the saved schema strings and check if the change is truly backward-
 | |
|         incompatible.
 | |
|         """
 | |
|         signature_strs = []
 | |
| 
 | |
|         for obj in _BACK_COMPAT_OBJECTS:
 | |
|             if not isinstance(obj, type):
 | |
|                 signature_strs.append(self._fn_to_stable_annotation_str(obj))
 | |
| 
 | |
|         signature_strs.sort()
 | |
| 
 | |
|         try:
 | |
|             self.assertExpected('\n'.join(signature_strs), 'fx_backcompat_function_signatures')
 | |
|         except AssertionError as e:
 | |
|             msg = f"{e}\n****** ERROR ******\nAn FX function that has been marked " \
 | |
|                   f"as backwards-compatible has experienced a signature change. See the " \
 | |
|                   f"above exception context for more information. If this change was " \
 | |
|                   f"unintended, please revert it. If it was intended, check with the FX " \
 | |
|                   f"team to ensure that the proper deprecation protocols have been followed " \
 | |
|                   f"and subsequently --accept the change."
 | |
|             raise AssertionError(msg)
 | |
| 
 | |
|     def test_class_member_back_compat(self):
 | |
|         """
 | |
|         Test backward compatibility for members of classes with
 | |
|         @compatibility(is_backward_compatible=True). Currently this checks for
 | |
|         exact matches on the publicly visible members of the class.
 | |
|         """
 | |
|         class_method_strs = []
 | |
| 
 | |
|         for obj in _BACK_COMPAT_OBJECTS:
 | |
|             if isinstance(obj, type):
 | |
|                 public_members = [name for name in obj.__dict__ if not name.startswith('_')]
 | |
|                 class_method_strs.append(f'{torch.typename(obj)} {sorted(public_members)}')
 | |
| 
 | |
|         class_method_strs.sort()
 | |
| 
 | |
|         try:
 | |
|             self.assertExpected('\n'.join(class_method_strs), 'fx_backcompat_class_members')
 | |
|         except AssertionError as e:
 | |
|             msg = f"{e}\n****** ERROR ******\nAn FX class that has been marked " \
 | |
|                   f"as backwards-compatible has experienced change in its public members. See the " \
 | |
|                   f"above exception context for more information. If this change was " \
 | |
|                   f"unintended, please revert it. If it was intended, check with the FX " \
 | |
|                   f"team to ensure that the proper deprecation protocols have been followed " \
 | |
|                   f"and subsequently --accept the change."
 | |
|             raise AssertionError(msg)
 | |
| 
 | |
|     def test_public_api_surface(self):
 | |
|         non_back_compat_objects = {}
 | |
| 
 | |
|         def check_symbols_have_bc_designation(m, prefix):
 | |
|             if not m.__name__.startswith('torch.fx'):
 | |
|                 return
 | |
|             if m.__name__.startswith('torch.fx.experimental'):
 | |
|                 return
 | |
|             for k, v in m.__dict__.items():
 | |
|                 if v is m:
 | |
|                     continue
 | |
|                 if k.startswith('_'):
 | |
|                     continue
 | |
|                 if isinstance(v, types.ModuleType):
 | |
|                     check_symbols_have_bc_designation(v, prefix + [k])
 | |
|                 elif isinstance(v, type) or isinstance(v, types.FunctionType):
 | |
|                     if v not in _MARKED_WITH_COMATIBLITY:
 | |
|                         non_back_compat_objects.setdefault(v)
 | |
| 
 | |
|         check_symbols_have_bc_designation(torch.fx, ['torch', 'fx'])
 | |
|         check_symbols_have_bc_designation(torch.fx.passes, ['torch', 'fx', 'passes'])
 | |
| 
 | |
|         non_back_compat_strs = [torch.typename(obj) for obj in non_back_compat_objects.keys()]
 | |
|         # Only want objects in torch.fx
 | |
|         non_back_compat_strs = [
 | |
|             s for s in non_back_compat_strs if s.startswith('torch.fx') and not s.startswith('torch.fx.experimental')]
 | |
|         # Only want objects in public namespaces
 | |
|         non_back_compat_strs = [
 | |
|             s for s in non_back_compat_strs if all(not atom.startswith('_') for atom in s.split('.'))]
 | |
|         non_back_compat_strs.sort()
 | |
| 
 | |
|         if len(non_back_compat_strs) != 0:
 | |
|             raise AssertionError(f"Public FX API(s) {non_back_compat_strs} introduced but not given a "
 | |
|                                  f"backwards-compatibility classification! Please decorate these "
 | |
|                                  f"API(s) with `@torch.fx._compatibility.compatibility` to specify "
 | |
|                                  f"BC guarantees.")
 | |
| 
 | |
| class TestFunctionalTracing(JitTestCase):
 | |
|     def setUp(self):
 | |
|         # Checking for mutable operations whil tracing is feature flagged
 | |
|         # Enable it in testing but not by default
 | |
|         self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = True
 | |
| 
 | |
|     def tearDown(self):
 | |
|         torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag
 | |
| 
 | |
|     IGNORE_FUNCS = ("has_torch_function", "has_torch_function_unary",
 | |
|                     "has_torch_function_variadic", "handle_torch_function",
 | |
|                     "boolean_dispatch")
 | |
|     TO_PATCH = {"has_torch_function": None,
 | |
|                 "has_torch_function_unary": None,
 | |
|                 "has_torch_function_variadic": None}
 | |
| 
 | |
|     BUILT_IN_FUNC = (AssertionError, "")
 | |
|     PROXY_ITERABLE = (TypeError, r"argument of type 'Proxy' is not iterable")
 | |
|     PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
 | |
|     LEN_ERROR = (RuntimeError, r"'len' is not supported in symbolic tracing by default")
 | |
|     ARG_TYPE_MISMATCH = (TypeError, r", not Proxy$")
 | |
|     CONTROL_FLOW = (TraceError, r"symbolically traced variables cannot be used as inputs to control flow")
 | |
|     INTERPOLATE_ARGS_CONFLICT = (ValueError, r"only one of size or scale_factor should be defined")
 | |
|     MUTABLE = (RuntimeError, r"Tried to trace mutable operation")
 | |
| 
 | |
|     UNTRACEABLE_FUNCTIONALS = {
 | |
|         "adaptive_avg_pool1d": BUILT_IN_FUNC,
 | |
|         "avg_pool1d": BUILT_IN_FUNC,
 | |
|         "avg_pool2d": BUILT_IN_FUNC,
 | |
|         "avg_pool3d": BUILT_IN_FUNC,
 | |
|         "celu_": BUILT_IN_FUNC,
 | |
|         "channel_shuffle": BUILT_IN_FUNC,
 | |
|         "conv1d": BUILT_IN_FUNC,
 | |
|         "conv2d": BUILT_IN_FUNC,
 | |
|         "conv3d": BUILT_IN_FUNC,
 | |
|         "conv_tbc": BUILT_IN_FUNC,
 | |
|         "conv_transpose1d": BUILT_IN_FUNC,
 | |
|         "conv_transpose2d": BUILT_IN_FUNC,
 | |
|         "conv_transpose3d": BUILT_IN_FUNC,
 | |
|         "cosine_similarity": BUILT_IN_FUNC,
 | |
|         "elu_": BUILT_IN_FUNC,
 | |
|         "hardtanh_": BUILT_IN_FUNC,
 | |
|         "leaky_relu_": BUILT_IN_FUNC,
 | |
|         "logsigmoid": BUILT_IN_FUNC,
 | |
|         "one_hot": BUILT_IN_FUNC,
 | |
|         "pdist": BUILT_IN_FUNC,
 | |
|         "pixel_shuffle": BUILT_IN_FUNC,
 | |
|         "pixel_unshuffle": BUILT_IN_FUNC,
 | |
|         "relu_": BUILT_IN_FUNC,
 | |
|         "rrelu_": BUILT_IN_FUNC,
 | |
|         "selu_": BUILT_IN_FUNC,
 | |
|         "softplus": BUILT_IN_FUNC,
 | |
|         "softshrink": BUILT_IN_FUNC,
 | |
|         "threshold_": BUILT_IN_FUNC,
 | |
| 
 | |
|         "adaptive_avg_pool2d": LEN_ERROR,
 | |
|         "adaptive_avg_pool3d": LEN_ERROR,
 | |
|         "adaptive_max_pool2d_with_indices": LEN_ERROR,
 | |
|         "adaptive_max_pool3d_with_indices": LEN_ERROR,
 | |
|         "instance_norm": CONTROL_FLOW,
 | |
|         "pad": LEN_ERROR,
 | |
| 
 | |
|         "adaptive_max_pool1d": PROXY_ITERABLE,
 | |
|         "adaptive_max_pool2d": PROXY_ITERABLE,
 | |
|         "adaptive_max_pool3d": PROXY_ITERABLE,
 | |
|         "fractional_max_pool2d": PROXY_ITERABLE,
 | |
|         "fractional_max_pool3d": PROXY_ITERABLE,
 | |
|         "max_pool1d": PROXY_ITERABLE,
 | |
|         "max_pool2d": PROXY_ITERABLE,
 | |
|         "max_pool3d": PROXY_ITERABLE,
 | |
| 
 | |
|         "group_norm": PROXY_ITERATED,
 | |
|         "lp_pool2d": PROXY_ITERATED,
 | |
|         "max_unpool1d": PROXY_ITERATED,
 | |
|         "max_unpool2d": PROXY_ITERATED,
 | |
|         "max_unpool3d": PROXY_ITERATED,
 | |
| 
 | |
|         "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "fractional_max_pool3d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "hardshrink": ARG_TYPE_MISMATCH,
 | |
|         "layer_norm": ARG_TYPE_MISMATCH,
 | |
|         "lp_pool1d": ARG_TYPE_MISMATCH,
 | |
|         "max_pool1d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "max_pool2d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "max_pool3d_with_indices": ARG_TYPE_MISMATCH,
 | |
|         "pairwise_distance": ARG_TYPE_MISMATCH,
 | |
| 
 | |
|         "affine_grid": CONTROL_FLOW,
 | |
|         "alpha_dropout": CONTROL_FLOW,
 | |
|         "batch_norm": CONTROL_FLOW,
 | |
|         "binary_cross_entropy": CONTROL_FLOW,
 | |
|         "binary_cross_entropy_with_logits": CONTROL_FLOW,
 | |
|         "celu": CONTROL_FLOW,
 | |
|         "cosine_embedding_loss": CONTROL_FLOW,
 | |
|         "cross_entropy": CONTROL_FLOW,
 | |
|         "ctc_loss": CONTROL_FLOW,
 | |
|         "dropout": CONTROL_FLOW,
 | |
|         "dropout2d": CONTROL_FLOW,
 | |
|         "dropout3d": CONTROL_FLOW,
 | |
|         "elu": CONTROL_FLOW,
 | |
|         "embedding": CONTROL_FLOW,
 | |
|         "embedding_bag": CONTROL_FLOW,
 | |
|         "feature_alpha_dropout": CONTROL_FLOW,
 | |
|         "fold": CONTROL_FLOW,
 | |
|         "gaussian_nll_loss": CONTROL_FLOW,
 | |
|         "glu": CONTROL_FLOW,
 | |
|         "grid_sample": CONTROL_FLOW,
 | |
|         "gumbel_softmax": CONTROL_FLOW,
 | |
|         "hardsigmoid": CONTROL_FLOW,
 | |
|         "hardswish": CONTROL_FLOW,
 | |
|         "hardtanh": CONTROL_FLOW,
 | |
|         "hinge_embedding_loss": CONTROL_FLOW,
 | |
|         "huber_loss": CONTROL_FLOW,
 | |
|         "interpolate": CONTROL_FLOW,
 | |
|         "kl_div": CONTROL_FLOW,
 | |
|         "l1_loss": CONTROL_FLOW,
 | |
|         "leaky_relu": CONTROL_FLOW,
 | |
|         "local_response_norm": CONTROL_FLOW,
 | |
|         "margin_ranking_loss": CONTROL_FLOW,
 | |
|         "mse_loss": CONTROL_FLOW,
 | |
|         "multi_head_attention_forward": CONTROL_FLOW,
 | |
|         "multi_margin_loss": CONTROL_FLOW,
 | |
|         "multilabel_margin_loss": CONTROL_FLOW,
 | |
|         "multilabel_soft_margin_loss": CONTROL_FLOW,
 | |
|         "nll_loss": CONTROL_FLOW,
 | |
|         "poisson_nll_loss": CONTROL_FLOW,
 | |
|         "relu": CONTROL_FLOW,
 | |
|         "relu6": CONTROL_FLOW,
 | |
|         "rrelu": CONTROL_FLOW,
 | |
|         "selu": CONTROL_FLOW,
 | |
|         "silu": CONTROL_FLOW,
 | |
|         "mish": CONTROL_FLOW,
 | |
|         "smooth_l1_loss": CONTROL_FLOW,
 | |
|         "soft_margin_loss": CONTROL_FLOW,
 | |
|         "threshold": CONTROL_FLOW,
 | |
|         "triplet_margin_loss": CONTROL_FLOW,
 | |
|         "triplet_margin_with_distance_loss": CONTROL_FLOW,
 | |
|         "unfold": CONTROL_FLOW,
 | |
|         "upsample": CONTROL_FLOW,
 | |
| 
 | |
|         "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
 | |
|         "upsample_nearest": INTERPOLATE_ARGS_CONFLICT,
 | |
| 
 | |
|         "normalize" : MUTABLE,
 | |
|     }
 | |
| 
 | |
    # nn.functional entries that take Tensor inputs but whose parameters carry
    # no Tensor type annotation. _get_functional drops any function whose
    # inspectable signature has no Tensor-annotated parameter, so names listed
    # here bypass that annotation check and are still collected for tracing.
    FUNCTIONALS_WITHOUT_ANNOTATION = (
        "adaptive_max_pool1d",
        "adaptive_max_pool2d",
        "adaptive_max_pool3d",
        "fractional_max_pool2d",
        "fractional_max_pool3d",
        "max_pool1d",
        "max_pool2d",
        "max_pool3d",
        "gaussian_nll_loss",
        "upsample",
        "upsample_bilinear",
        "upsample_nearest",
    )
 | |
| 
 | |
    # Inconsistent behavior between Python 3.8/3.9 and other Python versions:
    # - Python 3.8 and 3.9: re-raise the internal exception (e.g. `PROXY_ITERATED`)
    # - Other Python versions: raise `argument of type 'Proxy' is not iterable`
    #   due to the same internal exception above
    # generate_test_func consults this map, ahead of UNTRACEABLE_FUNCTIONALS,
    # only when (3, 8) <= sys.version_info < (3, 10). Each value is an
    # (exception type, message regex) pair handed to assertRaisesRegex.
    UNTRACEABLE_FUNCTIONALS_PY38 = {
        "adaptive_max_pool1d": PROXY_ITERATED,
        "adaptive_max_pool2d": PROXY_ITERATED,
        "adaptive_max_pool3d": PROXY_ITERATED,
        "fractional_max_pool2d": PROXY_ITERATED,
        "fractional_max_pool3d": PROXY_ITERATED,
        "max_pool1d": PROXY_ITERATED,
        "max_pool2d": PROXY_ITERATED,
        "max_pool3d": PROXY_ITERATED,

        "group_norm": LEN_ERROR
    }
 | |
| 
 | |
|     @classmethod
 | |
|     def _get_functional(cls):
 | |
|         functional_list = []
 | |
|         for f in dir(torch.nn.functional):
 | |
|             if not f.islower():
 | |
|                 continue
 | |
|             # Ignore internal functions
 | |
|             if f.startswith('_'):
 | |
|                 continue
 | |
|             # Ignore supporting functions
 | |
|             if f in cls.IGNORE_FUNCS:
 | |
|                 continue
 | |
|             fn = getattr(torch.nn.functional, f)
 | |
|             # Ignore non-callable object like modules
 | |
|             if not isinstance(fn, Callable):
 | |
|                 continue
 | |
|             if f not in cls.FUNCTIONALS_WITHOUT_ANNOTATION:
 | |
|                 try:
 | |
|                     sig = inspect.signature(fn)
 | |
|                     has_tensor_arg = False
 | |
|                     for arg, param in sig.parameters.items():
 | |
|                         if isinstance(param.annotation, type) and issubclass(param.annotation, torch.Tensor):
 | |
|                             has_tensor_arg = True
 | |
|                     if not has_tensor_arg:
 | |
|                         continue
 | |
|                 # No signature or Object is not supported
 | |
|                 except ValueError:
 | |
|                     pass
 | |
|             functional_list.append((f, fn))
 | |
|         return functional_list
 | |
| 
 | |
|     @classmethod
 | |
|     def generate_test_func(cls, func_name, fn):
 | |
| 
 | |
|         def functional_test(self):
 | |
|             if func_name in self.UNTRACEABLE_FUNCTIONALS_PY38 and \
 | |
|                     sys.version_info >= (3, 8) and sys.version_info < (3, 10):
 | |
|                 exc, err = self.UNTRACEABLE_FUNCTIONALS_PY38[func_name]
 | |
|                 with self.assertRaisesRegex(exc, err):
 | |
|                     symbolic_trace(fn)
 | |
|             elif func_name in self.UNTRACEABLE_FUNCTIONALS:
 | |
|                 exc, err = self.UNTRACEABLE_FUNCTIONALS[func_name]
 | |
|                 with self.assertRaisesRegex(exc, err):
 | |
|                     symbolic_trace(fn)
 | |
|             else:
 | |
|                 symbolic_trace(fn)
 | |
|         return functional_test
 | |
| 
 | |
|     @classmethod
 | |
|     def generate_tests(cls):
 | |
|         functional_list = cls._get_functional()
 | |
|         for func_name, fn in functional_list:
 | |
|             test_name = "test_nn_functional_" + func_name
 | |
|             functional_test = cls.generate_test_func(func_name, fn)
 | |
|             setattr(cls, test_name, functional_test)
 | |
| 
 | |
|     @classmethod
 | |
|     def setUpClass(cls):
 | |
| 
 | |
|         def no(*args, **kwargs):
 | |
|             return False
 | |
| 
 | |
|         for name in cls.TO_PATCH.keys():
 | |
|             cls.TO_PATCH[name] = getattr(torch.nn.functional, name)
 | |
|             setattr(torch.nn.functional, name, no)
 | |
| 
 | |
|     @classmethod
 | |
|     def tearDownClass(cls):
 | |
|         for name in cls.TO_PATCH.keys():
 | |
|             setattr(torch.nn.functional, name, cls.TO_PATCH[name])
 | |
| 
 | |
# Attach one generated test_nn_functional_* method per collected
# torch.nn.functional entry to TestFunctionalTracing at import time.
TestFunctionalTracing.generate_tests()



# Instantiate device-specific variants of TestOperatorSignatures into this
# module's namespace (helper from torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestOperatorSignatures, globals())
 | |
| 
 | |
@skipIfNoTorchVision
class TestVisionTracing(JitTestCase):
    """Symbolically trace, then script, torchvision models.

    The test methods are generated rather than written by hand:
    ``generate_tests`` attaches one ``test_torchvision_models_*`` method per
    public model builder found in ``torchvision_models`` and its
    ``segmentation``, ``detection`` and ``video`` submodules.
    """

    def setUp(self):
        # Checking for mutable operations while tracing is feature flagged.
        # Enable it in testing but not by default; the previous value is saved
        # so tearDown can restore it.
        # NOTE(review): setUp/tearDown do not chain to JitTestCase's versions;
        # confirm that skipping the base-class setup is intentional.
        self.orig_tracer_mutable_flag = torch.fx.proxy.TracerBase.check_mutable_operations
        torch.fx.proxy.TracerBase.check_mutable_operations = True

    def tearDown(self):
        # Restore the flag value saved in setUp.
        torch.fx.proxy.TracerBase.check_mutable_operations = self.orig_tracer_mutable_flag

    # Expected failures, as (exception type, message regex) pairs for assertRaisesRegex.
    PROXY_ITERATED = (TraceError, r"Proxy object cannot be iterated")
    INCONSISTENT_TYPE = (
        RuntimeError,
        r"Return value was annotated as having type __torch__.torchvision.models[.\w]+ but is actually of type Tensor"
    )

    # Models for which symbolic_trace is expected to raise.
    UNTRACEABLE_MODELS = {
        "fasterrcnn_resnet50_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_320_fpn": PROXY_ITERATED,
        "fasterrcnn_mobilenet_v3_large_fpn": PROXY_ITERATED,
        "maskrcnn_resnet50_fpn": PROXY_ITERATED,
        "keypointrcnn_resnet50_fpn": PROXY_ITERATED,
        "retinanet_resnet50_fpn": PROXY_ITERATED,
    }
    # Models that trace, but whose traced GraphModule is expected to fail torch.jit.script.
    UNSCRIPTABLE_MODELS = {
        "googlenet": INCONSISTENT_TYPE,
        "inception_v3": INCONSISTENT_TYPE,
    }

    # Per-model functions selecting the part of the output to compare; models
    # not listed here are compared on their raw output.
    output_transform = {
        "fcn_resnet50": lambda x: x["out"],
        "fcn_resnet101": lambda x: x["out"],
        "deeplabv3_resnet50": lambda x: x["out"],
        "deeplabv3_resnet101": lambda x: x["out"],
        "deeplabv3_mobilenet_v3_large": lambda x: x["out"],
        "lraspp_mobilenet_v3_large": lambda x: x["out"],
        "fasterrcnn_resnet50_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
        "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
        "maskrcnn_resnet50_fpn": lambda x: x[1],
        "keypointrcnn_resnet50_fpn": lambda x: x[1],
        "retinanet_resnet50_fpn": lambda x: x[1],
    }

    @classmethod
    def generate_test_fn(cls, name, model_fn, x, kwargs):
        """Build the test method for a single torchvision model builder.

        Args:
            name: builder name; used to look up expected failures and the
                output transform.
            model_fn: callable constructing the model from ``kwargs``.
            x: input fed to the eager, traced and scripted models.
            kwargs: keyword arguments forwarded to ``model_fn``.

        Returns:
            A test function taking ``self``. It builds the model in eval mode.
            For models in ``UNTRACEABLE_MODELS`` it asserts that tracing raises.
            Otherwise it checks that the traced module reproduces the eager
            output. It then checks the scripted traced module the same way,
            unless the model is in ``UNSCRIPTABLE_MODELS``, in which case it
            asserts that scripting raises.
        """
        def run_test(self):
            model = model_fn(**kwargs).eval()

            if name in self.UNTRACEABLE_MODELS:
                exc_type, pattern = self.UNTRACEABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, pattern):
                    symbolic_trace(model)
                return

            out_transform = self.output_transform.get(name, lambda out: out)
            graph: torch.fx.GraphModule = symbolic_trace(model)
            a = out_transform(model(x))
            b = out_transform(graph(x))
            self.assertEqual(a, b)

            if name in self.UNSCRIPTABLE_MODELS:
                exc_type, pattern = self.UNSCRIPTABLE_MODELS[name]
                with self.assertRaisesRegex(exc_type, pattern):
                    torch.jit.script(graph)
                return

            script = torch.jit.script(graph)
            c = out_transform(script(x))
            self.assertEqual(a, c)

        return run_test

    @classmethod
    def _generate_tests_from_module(cls, module, prefix, make_input, kwargs):
        """Attach a generated test for every public model builder in ``module``.

        A builder is any callable in ``module.__dict__`` whose name's first
        character is neither uppercase nor ``_``. Its test is named
        ``prefix + <builder name>``. ``make_input(name)`` supplies the model
        input, and a fresh copy of ``kwargs`` is forwarded to the builder.
        """
        for k, v in module.__dict__.items():
            if callable(v) and k[0].lower() == k[0] and k[0] != "_":
                x = make_input(k)
                setattr(cls, prefix + k, cls.generate_test_fn(k, v, x, dict(kwargs)))

    @classmethod
    def generate_classification_tests(cls):
        """Add tests for the classification builders in ``torchvision_models``."""
        def make_input(k):
            # inception_v3 is fed 299x299 images; every other classifier 224x224.
            return torch.rand(1, 3, 299, 299) if k in ['inception_v3'] else torch.rand(1, 3, 224, 224)

        cls._generate_tests_from_module(
            torchvision_models, 'test_torchvision_models_', make_input, dict(num_classes=50))

    @classmethod
    def generate_segmentation_tests(cls):
        """Add tests for the builders in ``torchvision_models.segmentation``."""
        cls._generate_tests_from_module(
            torchvision_models.segmentation,
            'test_torchvision_models_segmentation_',
            lambda k: torch.rand(1, 3, 32, 32),
            dict(num_classes=10, pretrained_backbone=False))

    @classmethod
    def generate_detection_tests(cls):
        """Add tests for the builders in ``torchvision_models.detection``."""
        # Detection models are called with a list of image tensors.
        cls._generate_tests_from_module(
            torchvision_models.detection,
            'test_torchvision_models_detection_',
            lambda k: [torch.rand(3, 300, 300)],
            dict(num_classes=10, pretrained_backbone=False))

    @classmethod
    def generate_video_tests(cls):
        """Add tests for the builders in ``torchvision_models.video``."""
        cls._generate_tests_from_module(
            torchvision_models.video,
            'test_torchvision_models_video_',
            lambda k: torch.rand(1, 3, 4, 112, 112),
            dict(num_classes=50))

    @classmethod
    def generate_tests(cls):
        """Generate classification, detection, segmentation and video model tests."""
        cls.generate_classification_tests()
        cls.generate_detection_tests()
        cls.generate_segmentation_tests()
        cls.generate_video_tests()
 | |
| 
 | |
# Generate the torchvision model tests only when HAS_TORCHVISION is set
# (defined elsewhere in this file; presumably by a guarded torchvision import
# that also provides torchvision_models).
if HAS_TORCHVISION:
    TestVisionTracing.generate_tests()

# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    run_tests()
 |