Enable UFMT on all of test/fx (#123622)

Partially addresses #123062

Ran lintrunner on:

- `test/fx`

with command:

```bash
lintrunner -a --take UFMT --all-files
```
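Most of the diff below is mechanical: UFMT (which runs usort and black under lintrunner) normalizes string quotes to double quotes, sorts and groups imports, and wraps long calls and signatures with one argument per line and a trailing comma. A minimal before/after sketch of that style of rewrite (the function and its arguments are hypothetical and not taken from the changed files):

```python
# Hypothetical snippet illustrating the rewrites UFMT applies; the names are
# made up for illustration and do not come from the changed files.

# Before: unsorted imports, single quotes, one long call on a single line:
#
#   import sys
#   import torch
#   import operator
#
#   def make_quantized_add(graph, scale, zero_point):
#       return graph.create_node('call_function', torch.ops.quantized.add, (), {'scale': scale, 'zero_point': zero_point})

# After: usort groups stdlib imports before third-party ones, and black
# rewrites quotes and wraps the long call with a trailing comma.
import operator
import sys

import torch


def make_quantized_add(graph, scale, zero_point):
    return graph.create_node(
        "call_function",
        torch.ops.quantized.add,
        (),
        {"scale": scale, "zero_point": zero_point},
    )
```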
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123622
Approved by: https://github.com/ezyang
Author: Yuanhao Ji
Date: 2024-04-09 15:59:15 +00:00
Committed by: PyTorch MergeBot
Parent: 3b3962f7b3
Commit: c96bd3de06
14 changed files with 1216 additions and 567 deletions

View File

@ -1178,21 +1178,6 @@ exclude_patterns = [
'test/functorch/test_vmap.py',
'test/functorch/test_vmap_registrations.py',
'test/functorch/xfail_suggester.py',
'test/fx/named_tup.py',
'test/fx/quantization.py',
'test/fx/test_common_passes.py',
'test/fx/test_cse_pass.py',
'test/fx/test_dce_pass.py',
'test/fx/test_future.py',
'test/fx/test_fx_const_fold.py',
'test/fx/test_fx_param_shape_control_flow.py',
'test/fx/test_gradual_type.py',
'test/fx/test_matcher_utils.py',
'test/fx/test_pass_infra.py',
'test/fx/test_source_matcher_utils.py',
'test/fx/test_subgraph_rewriter.py',
'test/fx/test_z3_gradual_types.py',
'test/fx/test_fx_split.py',
'test/jit/__init__.py',
'test/jit/_imported_class_test/__init__.py',
'test/jit/_imported_class_test/bar.py',

View File

@ -2,6 +2,7 @@ from typing import NamedTuple
import torch
class MyNamedTup(NamedTuple):
i : torch.Tensor
f : torch.Tensor
i: torch.Tensor
f: torch.Tensor

View File

@ -1,20 +1,24 @@
r'''
r"""
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
rely on it for anything!**
'''
"""
import operator
import sys
from typing import Optional
import torch
from torch.fx import Graph, GraphModule, Node
from torch.fx.graph import map_arg
from torch.fx.proxy import Proxy
import sys
import torch
from torch.nn.utils import fuse_conv_bn_weights
import operator
from typing import Optional
# can be a
# module type, a builtin function, or a string to match target
def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps):
def _minmax_scale_zeropoint(
min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps
):
min_val = min(0.0, min_val)
max_val = max(0.0, max_val)
if max_val == min_val:
@ -28,9 +32,10 @@ def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.fin
zero_point = int(zero_point)
return scale, zero_point
class MinMaxObserver:
def __init__(self, quantizer, node):
self.min, self.max = float('inf'), float('-inf')
self.min, self.max = float("inf"), float("-inf")
self.all_tensors = True
def observe(self, node, env):
@ -44,6 +49,7 @@ class MinMaxObserver:
def scale_zeropoint(self):
return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)
class NoObserver:
def __init__(self, quantizer, node):
pass
@ -51,11 +57,15 @@ class NoObserver:
def observe(self, node, env):
pass
_DEFAULT_QUANTIZATION_PATTERNS = {}
def register_pattern(pattern):
def insert(fn):
_DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert
@ -66,12 +76,19 @@ class Add(MinMaxObserver):
return NotImplemented
scale, zeropoint = self.scale_zeropoint()
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint})
"call_function",
torch.ops.quantized.add,
load_arg(node.args),
{"scale": scale, "zero_point": zeropoint},
)
class Relu(NoObserver):
def quantize(self, quantizer, node, load_arg):
return torch.relu(load_arg(node.args[0])) # torch.relu works directly on quantized tensors?
return torch.relu(
load_arg(node.args[0])
) # torch.relu works directly on quantized tensors?
# these ops have quantized equivalents that do not need any extra information
@register_pattern(torch.nn.ReLU)
@ -82,15 +99,24 @@ class CopyNode(NoObserver):
def quantize(self, quantizer, node, load_arg):
return quantizer.quantized_graph.node_copy(node, load_arg)
class IdentityModule(torch.nn.Module):
def forward(self, x):
return x
# handle conv, maybe followed by bn, maybe followed by relu
@register_pattern(torch.nn.modules.conv.Conv2d)
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))
@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)))
@register_pattern(
(torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
)
@register_pattern(
(
torch.nn.ReLU,
(torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
)
)
class ConvNormRelu(MinMaxObserver):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
@ -112,21 +138,41 @@ class ConvNormRelu(MinMaxObserver):
if self.bn_node is not None:
weight, bias = fuse_conv_bn_weights(
weight, bias, self.bn.running_mean, self.bn.running_var,
self.bn.eps, self.bn.weight, self.bn.bias)
weight,
bias,
self.bn.running_mean,
self.bn.running_var,
self.bn.eps,
self.bn.weight,
self.bn.bias,
)
min_val, max_val = float(weight.min()), float(weight.max())
act_scale, act_zp = self.scale_zeropoint()
weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)
qweight = torch.quantize_per_tensor(
weight, weight_scale, weight_zp, torch.qint8
)
ctor = torch.ao.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.ao.nn.quantized.Conv2d
ctor = (
torch.ao.nn.intrinsic.quantized.ConvReLU2d
if self.relu_node is not None
else torch.ao.nn.quantized.Conv2d
)
qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
mod.stride, mod.padding, mod.dilation, mod.groups,
mod.bias is not None, mod.padding_mode)
qconv = ctor(
mod.in_channels,
mod.out_channels,
mod.kernel_size,
mod.stride,
mod.padding,
mod.dilation,
mod.groups,
mod.bias is not None,
mod.padding_mode,
)
qconv.set_weight_bias(qweight, bias)
qconv.scale = float(act_scale)
@ -139,24 +185,31 @@ class ConvNormRelu(MinMaxObserver):
# try to call it, so replace with something that does nothing.
setattr(quantizer.modules[parent_name], bn_name, IdentityModule())
return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {})
return quantizer.quantized_graph.create_node(
"call_module",
self.conv_node.target,
(load_arg(self.conv_node.args[0]),),
{},
)
# turn foo.bar -> ['foo', 'bar']
def _parent_name(target):
r = target.rsplit('.', 1)
r = target.rsplit(".", 1)
if len(r) == 1:
return '', r[0]
return "", r[0]
else:
return r[0], r[1]
class DefaultQuant(MinMaxObserver):
def quantize(self, input):
assert self.all_tensors
scale, zeropoint = self.scale_zeropoint()
return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node
return torch.quantize_per_tensor(
Proxy(input), scale, zeropoint, torch.quint8
).node
def matches(modules, node, pattern, max_uses=sys.maxsize):
if isinstance(pattern, tuple):
@ -169,12 +222,12 @@ def matches(modules, node, pattern, max_uses=sys.maxsize):
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
if node.op != "call_module":
return False
if not isinstance(modules[node.target], self_match):
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
if node.op != "call_function" or node.target is not self_match:
return False
elif node.target != self_match:
return False
@ -185,11 +238,16 @@ def matches(modules, node, pattern, max_uses=sys.maxsize):
if len(arg_matches) != len(node.args):
return False
return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
return all(
matches(modules, node, arg_match, max_uses=1)
for node, arg_match in zip(node.args, arg_matches)
)
class Quantizer:
def __init__(self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
def __init__(
self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
):
self.root = mod
self.graph = mod.graph
self.quant_ctor = quant_ctor
@ -205,8 +263,6 @@ class Quantizer:
# initialize a quant_ctor object for each
self.quants = self._find_quants(quant_ctor)
def observe(self, args):
# most of this function is just an interpreter for the graph
# it would be possible to put this in some abstraction, but
@ -220,21 +276,23 @@ class Quantizer:
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
output_node : Optional[Node] = None
output_node: Optional[Node] = None
for node in self.graph.nodes:
if node.op == 'placeholder':
if node.op == "placeholder":
result = next(args_iter)
elif node.op == 'get_attr':
elif node.op == "get_attr":
result = self.state_dict[node.target]
elif node.op == 'call_function':
elif node.op == "call_function":
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'call_method':
elif node.op == "call_method":
self_obj, *args = load_arg(node.args)
kwargs = load_arg(node.kwargs)
result = getattr(self_obj, node.target)(*args, **kwargs)
elif node.op == 'call_module':
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
elif node.op == 'output':
elif node.op == "call_module":
result = self.modules[node.target](
*load_arg(node.args), **load_arg(node.kwargs)
)
elif node.op == "output":
return load_arg(node.args[0])
env[node.name] = result
@ -244,7 +302,7 @@ class Quantizer:
if node.name in self.quants:
self.quants[node.name].observe(node, env)
raise RuntimeError('Graph had no output node!')
raise RuntimeError("Graph had no output node!")
def quantize(self):
self.quantized_graph = Graph()
@ -268,17 +326,26 @@ class Quantizer:
return load_arg(n, quantized=False)
else:
return copy_recursive(n)
r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
r = env[node.name] = self.quantized_graph.node_copy(
node, lambda n: load_arg(n, quantized=False)
)
return r
for node in self.graph.nodes:
root_node, obj = self.matches.get(node.name, (None, None))
if root_node is None:
# not quantized just copy it
env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
env[node.name] = self.quantized_graph.node_copy(
node, lambda n: load_arg(n, quantized=False)
)
elif root_node is node:
r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
r = obj.quantize(
self,
node,
lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
)
if r is NotImplemented:
# quantizer chose not to quantize the node; take the entire match and just copy it over
env[node.name] = copy_recursive(node)
@ -318,6 +385,7 @@ class Quantizer:
# say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
if n.name not in quants:
quants[n.name] = quant_ctor(self, n)
for node in self.graph.nodes:
if node.name in self.matches:
map_arg(node.args, visit_arg)

View File

@ -1,14 +1,19 @@
# Owner(s): ["oncall: fx"]
import itertools
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.testing._internal.common_utils import (
TestCase, parametrize, instantiate_parametrized_tests, run_tests)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.dialect.common.cse_pass import CSEPass
from torch.fx.graph_module import GraphModule
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
import itertools
def FactoryFunctionCall(x, device):
y = torch.full(x.shape, 3, device=device)
@ -62,12 +67,14 @@ def MutationMetadata(x):
Passes = [CSEPass]
Test_Cases = [TakeList,
ReturnList,
Mutation,
MutationInput,
MutationMetadata,
MutationTorchTensorCall]
Test_Cases = [
TakeList,
ReturnList,
Mutation,
MutationInput,
MutationMetadata,
MutationTorchTensorCall,
]
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]
Devices = ["cpu"]
if torch.cuda.is_available():
@ -76,12 +83,14 @@ if torch.cuda.is_available():
def name_fn(common_pass, f, device):
"""Names parameterized test cases."""
return f'{type(common_pass()).__name__}_{f.__name__}_{device}'
return f"{type(common_pass()).__name__}_{f.__name__}_{device}"
@instantiate_parametrized_tests
class TestCommonPass(TestCase):
@parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn)
@parametrize(
"common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
)
def test_correctness(self, common_pass, f, device):
inp = torch.randn(10, device=device)
@ -98,8 +107,11 @@ class TestCommonPass(TestCase):
self.assertEqual(result, expected)
@parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn)
@parametrize(
"common_pass,f,device",
itertools.product(Passes, Factory_Test_Cases, Devices),
name_fn,
)
def test_correctness_factory(self, common_pass, f, device):
inp = torch.randn(10, device=device)
traced_m = make_fx(f)(inp, device)
@ -116,5 +128,5 @@ class TestCommonPass(TestCase):
self.assertEqual(result, expected)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -1,19 +1,19 @@
# Owner(s): ["oncall: fx"]
import torch
import random
from torch.testing._internal.common_utils import (
TestCase, run_tests)
import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
from torch.fx import symbolic_trace
import random
from torch.testing._internal.common_utils import run_tests, TestCase
banned_ops = get_CSE_banned_ops()
P_default = CSEPass(banned_ops=banned_ops)
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
"""
check if the CSE modified graph of ``f``
@ -47,34 +47,50 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
old_num_nodes = len(fx_g.graph.nodes)
new_num_nodes = len(new_graph.nodes)
assert (new_num_nodes < old_num_nodes) == modified, "modified should be True if the number of nodes decrease"
assert (
new_num_nodes < old_num_nodes
) == modified, "modified should be True if the number of nodes decrease"
if delta == -1:
self.assertTrue(old_num_nodes >= new_num_nodes, (
f"number of nodes increased {old_num_nodes}, {new_num_nodes}"))
self.assertTrue(
old_num_nodes >= new_num_nodes,
(f"number of nodes increased {old_num_nodes}, {new_num_nodes}"),
)
else:
self.assertTrue(old_num_nodes == new_num_nodes + delta, (
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"))
self.assertTrue(
old_num_nodes == new_num_nodes + delta,
(
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
),
)
# a second pass should not reduce more nodes
res = P(new_g)
pass_2_graph = res.graph_module.graph
pass_2_num_nodes = len(pass_2_graph.nodes)
self.assertTrue(pass_2_num_nodes == new_num_nodes, (
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"))
self.assertTrue(
pass_2_num_nodes == new_num_nodes,
(
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
),
)
# check correctness
if check_val:
true_result = fx_g(t)
our_result = new_g(t)
if true_result is None: # both return None
self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}")
self.assertTrue(
our_result is None, f"true result is None, CSE result is {our_result}"
)
else: # results returned are the same
self.assertTrue(torch.all(true_result == our_result), (
f"results are different {true_result}, {our_result}")) # check results are the same
self.assertTrue(
torch.all(true_result == our_result),
(f"results are different {true_result}, {our_result}"),
) # check results are the same
class TestCSEPass(TestCase):
def test_nochange(self):
def f(x):
a = x + 1
@ -82,16 +98,17 @@ class TestCSEPass(TestCase):
a = x
d = x + a
return b + d
t = torch.randn(2, 2)
check(self, f, t, 0)
def test_empty(self):
def f(x):
pass
t = torch.randn(2, 2)
check(self, f, t, 0)
def test_immutable_list_type(self):
def f(x):
a = x.sum(dim=1)
@ -99,6 +116,7 @@ class TestCSEPass(TestCase):
c = x.sum()
d = x.sum()
return a + b + c + d
t = torch.randn(2, 2)
check(self, f, t, 2)
@ -109,6 +127,7 @@ class TestCSEPass(TestCase):
c = x.sum(dim=1)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(self, f, t, 2)
@ -119,6 +138,7 @@ class TestCSEPass(TestCase):
c = a + a
d = b + b
return c + d
t = torch.randn(2, 2)
check(self, f, t, 2)
@ -129,6 +149,7 @@ class TestCSEPass(TestCase):
c = a + a
d = b + b
return c + d
t = torch.randn(1)
check(self, f, t, 3)
@ -139,6 +160,7 @@ class TestCSEPass(TestCase):
c = x.sum(dim=1, keepdim=False)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(self, f, t, 3)
@ -149,6 +171,7 @@ class TestCSEPass(TestCase):
c = x.sum(dim=1, keepdim=True)
d = x.sum(dim=1)
return a + b + c + d
t = torch.randn(2, 2)
check(self, f, t, 2)
@ -159,6 +182,7 @@ class TestCSEPass(TestCase):
c = x.sum()
d = x.sum()
return a + b + c + d
t = torch.randn(2, 2)
check(self, f, t, 3)
@ -167,6 +191,7 @@ class TestCSEPass(TestCase):
a = torch.cat((x, x))
b = torch.cat((x, x))
return a + b
t = torch.randn(2, 2)
check(self, f, t, 1)
@ -175,12 +200,14 @@ class TestCSEPass(TestCase):
a = torch.ones_like(x)
b = torch.ones_like(x)
return a + b
t = torch.randn(2, 2)
check(self, f, t, 1)
"""
Generate function with random ops and check if the result is the same
"""
def test_random(self):
def f(x):
vals = [x]
@ -201,6 +228,7 @@ class TestCSEPass(TestCase):
"""
Test that banned list ban ops as expected.
"""
def test_banned_list(self):
def f(x):
a = x + 1
@ -217,6 +245,7 @@ class TestCSEPass(TestCase):
a = torch.rand_like(x)
b = torch.rand_like(x)
return a + b
t = torch.randn(2, 2)
check(self, f, t, 0, check_val=False)
@ -225,9 +254,10 @@ class TestCSEPass(TestCase):
a = torch.randn(4)
b = torch.randn(4)
return a + b
t = torch.randn(2, 2)
check(self, f, t, 0, check_val=False)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: fx"]
from typing import Set, Type
import torch
import torch.fx

View File

@ -1,34 +1,42 @@
# Owner(s): ["module: fx"]
from __future__ import annotations # type: ignore[attr-defined]
import torch
from __future__ import annotations # type: ignore[attr-defined]
import typing
import torch
from torch.fx import symbolic_trace
class A:
def __call__(self, x: torch.Tensor):
return torch.add(x, x)
# No forward references
class M1(torch.nn.Module):
def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
return a(x)
# Forward references
class M2(torch.nn.Module):
def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
return a(x)
# Non-torch annotation with no internal forward references
class M3(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0])
# Non-torch annotation with internal forward references
class M4(torch.nn.Module):
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
return a(x[0])
x = torch.rand(2, 3)
ref = torch.add(x, x)

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: fx"]
import unittest
import torch
import torch.fx
@ -21,6 +22,7 @@ class MyModuleBase(torch.nn.Module):
def no_relu(self):
raise Exception("not implemented")
class MyModuleParamShape(MyModuleBase):
def __init__(self, in_channels):
super().__init__()
@ -72,7 +74,6 @@ class MyModuleParamNumEl(MyModuleBase):
return self.param.numel() < 10 * 3
class MyModuleParamNElement(MyModuleBase):
def __init__(self, in_channels):
super().__init__()
@ -82,43 +83,49 @@ class MyModuleParamNElement(MyModuleBase):
return self.param.nelement() < 10 * 3
class TestConstParamShapeInControlFlow(TestCase):
def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
"""
Verify one module only does a mm op while the other
performs both mm and relu ops in cascade
"""
x = torch.randn(10, 5)
torch.testing.assert_close(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
torch.testing.assert_close(
mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())
)
tracer = torch.fx.Tracer(param_shapes_constant=True)
traced_graph = tracer.trace(mm_only_mod)
# verify the graph module calculates the same result
graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
torch.testing.assert_close(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
torch.testing.assert_close(
graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())
)
# Make a new module with different parameter shape to go down the different
# code path
x = torch.randn(10, 15)
torch.testing.assert_close(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
torch.testing.assert_close(
relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
)
tracer2 = torch.fx.Tracer(param_shapes_constant=True)
traced_graph2 = tracer2.trace(relu_mod)
# verify the graph module calculates the same result
graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
torch.testing.assert_close(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
torch.testing.assert_close(
graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
)
graph1_node_targets = [n.target for n in traced_graph.nodes]
graph2_node_targets = [n.target for n in traced_graph2.nodes]
# the second graph has an extra relu function call node
assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
assert (
torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
)
def test_param_shape_const(self):
mymod = MyModuleParamShape(in_channels=5)
@ -151,5 +158,5 @@ class TestConstParamShapeInControlFlow(TestCase):
self.verify_mm_relu_mods(mymod, mymod2)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,7 +1,7 @@
# Owner(s): ["module: fx"]
from collections import defaultdict
from typing import List, Tuple, Dict
from typing import Dict, List, Tuple
import torch
from torch.fx.passes.split_utils import split_by_tags
@ -40,6 +40,7 @@ class TestFXSplit(TestCase):
self.assertIn("name", node.meta)
self.assertEqual(node.meta["name"], node.name)
class TestSplitByTags(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
@ -151,6 +152,7 @@ class TestSplitByTags(TestCase):
f"{orig_to_split_fqn_mapping=}",
)
class TestSplitOutputType(TestCase):
class TestModule(torch.nn.Module):
def __init__(self):

View File

@ -1,18 +1,22 @@
# Owner(s): ["module: fx"]
import unittest
import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.unify_refinements import infer_symbolic_types
from torch.fx.experimental.refinement_types import Equality
from torch.fx.tensor_type import TensorType, Dyn, is_consistent, is_more_precise
from torch.fx.annotate import annotate
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, broadcast_types, Refine
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx import GraphModule
from torch.fx.passes.shape_prop import ShapeProp
from torch.testing._internal.common_utils import TestCase
import sympy
import torch
from torch.fx import GraphModule, symbolic_trace
from torch.fx.annotate import annotate
from torch.fx.experimental.graph_gradual_typechecker import (
broadcast_types,
GraphTypeChecker,
Refine,
)
from torch.fx.experimental.refinement_types import Equality
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.unify_refinements import infer_symbolic_types
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
from torch.testing._internal.common_utils import TestCase
try:
@ -23,24 +27,33 @@ except ImportError:
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
return torch.nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation,
)
class AnnotationsTest(TestCase):
def test_annotations(self):
"""
Test type annotations in the forward function.
The annotation should appear in the n.graph
where n is the corresponding node in the resulting graph.
"""
class M(torch.nn.Module):
def forward(self,
x: TensorType((1, 2, 3, Dyn)),
y: Dyn,
z: TensorType[Dyn, 3, Dyn]):
def forward(
self, x: TensorType((1, 2, 3, Dyn)), y: Dyn, z: TensorType[Dyn, 3, Dyn]
):
return torch.add(x, y) + z
module = M()
@ -50,20 +63,19 @@ class AnnotationsTest(TestCase):
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == next(expected_iter)
def test_annotate(self):
class M(torch.nn.Module):
def forward(self, x):
y = annotate(x, TensorType((1, 2, 3, Dyn)))
return torch.add(x, y)
module = M()
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
for n in symbolic_traced.graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == TensorType((1, 2, 3, Dyn))
def test_consistency(self):
@ -83,7 +95,9 @@ class AnnotationsTest(TestCase):
self.assertTrue(is_more_precise(TensorType((1, 2, 3)), TensorType((1, Dyn, 3))))
self.assertTrue(is_more_precise(int, Dyn))
self.assertTrue(is_more_precise(int, int))
self.assertFalse(is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5))))
self.assertFalse(
is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))
)
self.assertFalse(is_more_precise(TensorType((1, 2, 3)), int))
def test_broadcasting1(self):
@ -94,7 +108,10 @@ class AnnotationsTest(TestCase):
t5 = TensorType((4, 4, 4))
# todo switch all code to use list instead of tuple
t6 = TensorType([1])
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4)))
assert broadcast_types(t1, t2) == (
TensorType((1, 2, 3, 4)),
TensorType((1, 2, 3, 4)),
)
assert broadcast_types(t3, t4) == (t4, t4)
assert broadcast_types(t5, t6) == (t5, t5)
@ -102,46 +119,58 @@ class AnnotationsTest(TestCase):
t1 = TensorType((2, 3, 4))
t2 = TensorType((1, 2, 1, 4))
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4)))
assert broadcast_types(t1, t2) == (
TensorType((1, 2, 3, 4)),
TensorType((1, 2, 3, 4)),
)
def test_broadcasting3(self):
t1 = TensorType((1, 2, 3, Dyn))
t2 = TensorType((2, 3, 4))
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, Dyn)), TensorType((1, 2, 3, 4)))
assert broadcast_types(t1, t2) == (
TensorType((1, 2, 3, Dyn)),
TensorType((1, 2, 3, 4)),
)
class TypeCheckerTest(TestCase):
def test_type_check_add_with_broadcast(self):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
expected_ph_types = [TensorType((1, 2, 3, Dyn)),
TensorType((2, 3, 4)),
TensorType((1, 2, 3, Dyn)),
TensorType((1, 2, 3, Dyn))]
expected_ph_types = [
TensorType((1, 2, 3, Dyn)),
TensorType((2, 3, 4)),
TensorType((1, 2, 3, Dyn)),
TensorType((1, 2, 3, Dyn)),
]
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
if n.op == 'call_function':
assert n.meta['broadcast']
if n.op == "call_function":
assert n.meta["broadcast"]
assert n.type == next(expected_iter)
def test_type_check_add_with_scalar(self):
class M(torch.nn.Module):
def forward(self, x: int, y: TensorType((2, 3, 4))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
expected_ph_types = [int,
TensorType((2, 3, 4)),
TensorType((2, 3, 4)),
TensorType((2, 3, 4))]
expected_ph_types = [
int,
TensorType((2, 3, 4)),
TensorType((2, 3, 4)),
TensorType((2, 3, 4)),
]
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
@ -151,6 +180,7 @@ class TypeCheckerTest(TestCase):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((1, 2, 3))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
@ -161,6 +191,7 @@ class TypeCheckerTest(TestCase):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, Dyn)), y: TensorType((1, 2, 3))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
@ -170,9 +201,9 @@ class TypeCheckerTest(TestCase):
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == next(expected_iter)
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((1, 2, Dyn))
def test_type_check_reshape_true(self):
@ -186,13 +217,13 @@ class TypeCheckerTest(TestCase):
self.assertTrue(tc.type_check())
for n in symbolic_traced.graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == TensorType((1, 6))
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == TensorType((1, 2, 3))
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((1, 2, 3))
def test_type_check_reshape_false(self):
@ -244,16 +275,16 @@ class TypeCheckerTest(TestCase):
return torch.transpose(x, 0, 1)
module = M()
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
self.assertTrue(tc.type_check())
for n in symbolic_traced.graph.nodes:
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == TensorType([2, 1, 3, 5])
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType([2, 1, 3, 5])
if n.op == 'x':
if n.op == "x":
assert n.placeholder == TensorType([1, 2, 3, 5])
def test_type_check_transpose_False(self):
@ -269,7 +300,6 @@ class TypeCheckerTest(TestCase):
def test_type_check_batch_norm_2D(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes):
super().__init__()
norm_layer = torch.nn.BatchNorm2d
@ -289,18 +319,17 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == TensorType((2, 2, 5, 4))
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((2, 2, 5, 4))
if n.op == 'call_module':
if n.op == "call_module":
assert n.type == TensorType((2, 2, 5, 4))
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == TensorType((2, 2, 5, 4))
def test_type_check_batch_norm_2D_false(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes):
super().__init__()
norm_layer = torch.nn.BatchNorm2d
@ -322,7 +351,6 @@ class TypeCheckerTest(TestCase):
def test_type_check_batch_norm_2D_broadcast(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes):
super().__init__()
norm_layer = torch.nn.BatchNorm2d
@ -341,13 +369,13 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, traced)
tc.type_check()
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_module':
if n.op == "call_module":
assert n.type == TensorType((2, 2, Dyn, 4))
B = BasicBlock(1, 1)
@ -379,13 +407,13 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, traced)
tc.type_check()
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
if n.op == 'call_module':
if n.op == "call_module":
assert n.type == TensorType((2, 2, Dyn, 4))
def test_type_check_conv2D_2(self):
@ -412,13 +440,13 @@ class TypeCheckerTest(TestCase):
tc.type_check()
t = TensorType((5, 2, 3, 4))
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == t
if n.op == 'call_function':
if n.op == "call_function":
assert n.type == t
if n.op == 'output':
if n.op == "output":
assert torch.Size(n.type.__args__) == b.shape
if n.op == 'call_module':
if n.op == "call_module":
assert n.type == t
B = BasicBlock(1, 2)
@ -430,12 +458,27 @@ class TypeCheckerTest(TestCase):
tc.type_check()
def test_type_check_conv2D_2_fully_static(self):
annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (1, 2, 2, 3)]
intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5),
(10, 15, 7, 7), (1, Dyn, Dyn, Dyn)]
annotation_list = [
(1, 2, 3, 5),
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, Dyn, 13, 14),
(Dyn, Dyn, Dyn, 3),
]
input_list = [
(1, 2, 3, 5),
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, 15, 13, 14),
(1, 2, 2, 3),
]
intermediate_types = [
(1, Dyn, Dyn, 7),
(2, Dyn, 4, 6),
(10, 15, Dyn, 5),
(10, 15, 7, 7),
(1, Dyn, Dyn, Dyn),
]
in_planes_list = [2, 5, 15, 15, 2]
stride_list = [1, 2, 3, 2, 2]
out_planes_list = [2, 5, 15, 15, 2]
@ -443,7 +486,13 @@ class TypeCheckerTest(TestCase):
dilation_list = [1, 2, 3, 3, 3]
padding_list = [1, 2, 3, 3, 3]
kernel_size_list = [1, 2, 3, 3, 3]
output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)]
output_types = [
(1, 2, Dyn, 7),
(2, 5, 4, 6),
(10, 15, Dyn, 5),
(10, 15, 7, 7),
(1, 2, Dyn, Dyn),
]
for i in range(5):
annotation = annotation_list[i]
@ -458,24 +507,42 @@ class TypeCheckerTest(TestCase):
intermediate_type = intermediate_types[i]
class BasicBlock(torch.nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups,
dilation,
):
super().__init__()
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, groups=groups, bias=False, dilation=dilation)
self.conv1 = torch.nn.Conv2d(
in_channels=in_planes,
out_channels=out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
dilation=dilation,
)
def forward(self, x):
out = self.conv1(x)
return out
B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
B = BasicBlock(
in_planes, out_planes, kernel_size, stride, padding, groups, dilation
)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")
# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
n.type = TensorType(annotation)
b = B.forward(torch.rand(input))
@ -483,36 +550,54 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in graph.nodes:
if n.op == 'output':
if n.op == "output":
assert is_consistent(n.type, TensorType(b.size()))
# test with intermediate annotations
class BasicBlock(torch.nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
def __init__(
self,
in_planes,
out_planes,
kernel_size,
stride,
padding,
groups,
dilation,
):
super().__init__()
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=kernel_size, stride=stride,
padding=padding, groups=groups, bias=False, dilation=dilation)
self.conv1 = torch.nn.Conv2d(
in_channels=in_planes,
out_channels=out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
dilation=dilation,
)
def forward(self, x):
out = self.conv1(x)
return out
B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
B = BasicBlock(
in_planes, out_planes, kernel_size, stride, padding, groups, dilation
)
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(B)
traced = GraphModule(ast_rewriter.root, graph, "gm")
# populate our intermediate nodes
for n in traced.graph.nodes:
if n.op == 'call_module':
if n.op == "call_module":
n.type = TensorType(intermediate_type)
tc = GraphTypeChecker({}, traced)
tc.type_check()
for n in traced.graph.nodes:
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType(output_types[i])
assert is_consistent(n.type, TensorType(b.size()))
@ -520,14 +605,26 @@ class TypeCheckerTest(TestCase):
class BasicBlock(torch.nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1):
def __init__(
self,
inplanes,
planes,
stride=1,
downsample=None,
groups=1,
base_width=64,
dilation=1,
):
super().__init__()
norm_layer = torch.nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
raise ValueError(
"BasicBlock only supports groups=1 and base_width=64"
)
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock"
)
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
@ -565,12 +662,14 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in traced.graph.nodes:
if n.target == 'output':
if n.target == "output":
assert isinstance(n.type, TensorType)
assert torch.Size(n.type.__args__) == B.forward(torch.rand(2, 2, 4, 5)).size()
assert (
torch.Size(n.type.__args__)
== B.forward(torch.rand(2, 2, 4, 5)).size()
)
def test_type_check_conv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
super().__init__()
@ -581,7 +680,7 @@ class TypeCheckerTest(TestCase):
self.fc1 = torch.nn.Linear(5, 120)
self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))
def forward(self, x : TensorType((4, 3, 32, 32))):
def forward(self, x: TensorType((4, 3, 32, 32))):
out = self.conv1(x)
out = self.pool(out)
out = self.conv2(out)
@ -598,10 +697,17 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, traced)
tc.type_check()
expected_ph_types = [TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)),
TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)),
TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)),
TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672))]
expected_ph_types = [
TensorType((4, 3, 32, 32)),
TensorType((4, 6, 28, 28)),
TensorType((4, 6, 14, 14)),
TensorType((4, 16, 10, 10)),
TensorType((4, 16, 5, 5)),
TensorType((4, 16, 5, 120)),
TensorType((4, 16, 6, 7)),
TensorType((4, 672)),
TensorType((4, 672)),
]
expected_iter = iter(expected_ph_types)
traced.graph.eliminate_dead_code()
@ -619,10 +725,9 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
for n in symbolic_traced.graph.nodes:
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((1, 6, 5, Dyn))
def test_type_check_flatten_2(self):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))):
@ -633,7 +738,7 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
for n in symbolic_traced.graph.nodes:
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((1, Dyn, 5, Dyn))
def test_type_check_flatten3(self):
@ -646,7 +751,7 @@ class TypeCheckerTest(TestCase):
tc = GraphTypeChecker({}, symbolic_traced)
tc.type_check()
for n in symbolic_traced.graph.nodes:
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType((2, 60))
r = Refine(symbolic_traced)
r.refine()
@ -654,13 +759,12 @@ class TypeCheckerTest(TestCase):
assert c == [Equality(2, 2)]
def test_type_typechecl_maxpool2d_3dinput(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
super().__init__()
self.pool = torch.nn.MaxPool2d(5, 8)
def forward(self, x : TensorType((64, 8, 8))):
def forward(self, x: TensorType((64, 8, 8))):
out = self.pool(x)
return out
@ -672,21 +776,42 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in traced.graph.nodes:
if n.target == 'output':
if n.target == "output":
assert n.type == TensorType((64, 1, 1))
def test_type_maxpool2d_fully_static(self):
annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)]
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (2, 2, 10, 10)]
intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2),
(10, 15, 2, 3), (2, Dyn, Dyn, Dyn)]
annotation_list = [
(Dyn, Dyn, 3, 5),
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, Dyn, 13, 14),
(Dyn, Dyn, Dyn, 10),
]
input_list = [
(1, 2, 3, 5),
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, 15, 13, 14),
(2, 2, 10, 10),
]
intermediate_types = [
(1, 2, Dyn, Dyn),
(2, Dyn, 2, 4),
(10, 15, Dyn, 2),
(10, 15, 2, 3),
(2, Dyn, Dyn, Dyn),
]
stride_list = [1, 2, 3, 2, 1]
dilation_list = [1, 2, 3, 3, 2]
padding_list = [1, 2, 3, 3, 1]
kernel_size_list = [2, 4, 6, 6, 3]
output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)]
output_types = [
(1, 2, 4, 6),
(2, 5, 2, 4),
(10, 15, 2, 2),
(10, 15, 2, 3),
(2, Dyn, Dyn, 8),
]
for i in range(5):
annotation = annotation_list[i]
@ -700,9 +825,14 @@ class TypeCheckerTest(TestCase):
class BasicBlock(torch.nn.Module):
def __init__(self, kernel_size, stride, padding, dilation):
super().__init__()
self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride,
padding=padding, dilation=dilation,
return_indices=False, ceil_mode=False)
self.pool = torch.nn.MaxPool2d(
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=False,
ceil_mode=False,
)
def forward(self, x):
out = self.pool(x)
@ -715,7 +845,7 @@ class TypeCheckerTest(TestCase):
# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
n.type = TensorType(annotation)
b = B.forward(torch.rand(input))
@ -723,16 +853,21 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in graph.nodes:
if n.op == 'output':
if n.op == "output":
assert is_consistent(n.type, TensorType(b.size()))
# test with intermediate annotations
class BasicBlock(torch.nn.Module):
def __init__(self, kernel_size, stride, padding, dilation):
super().__init__()
self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride,
padding=padding, dilation=dilation,
return_indices=False, ceil_mode=False)
self.pool = torch.nn.MaxPool2d(
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
return_indices=False,
ceil_mode=False,
)
def forward(self, x):
out = self.pool(x)
@ -745,30 +880,45 @@ class TypeCheckerTest(TestCase):
# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
n.type = TensorType(annotation)
# populate our intermediate nodes
for n in traced.graph.nodes:
if n.op == 'call_module':
if n.op == "call_module":
n.type = TensorType(intermediate_type)
tc = GraphTypeChecker({}, traced)
tc.type_check()
for n in traced.graph.nodes:
if n.op == 'output':
if n.op == "output":
assert n.type == TensorType(output_types[i])
assert is_consistent(n.type, TensorType(b.size()))
def test_flatten_fully_static(self):
annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)),
TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))]
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (2, 2, 10, 10)]
annotation_list = [
Dyn,
TensorType((2, 5, 6, 9)),
TensorType((10, 15, 13, 14)),
TensorType((10, Dyn, 13, 14)),
TensorType((Dyn, Dyn, Dyn, 10)),
]
input_list = [
(1, 2, 3, 5),
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, 15, 13, 14),
(2, 2, 10, 10),
]
intermediate_list = [Dyn, (2, 5, 6, 9), (10, 15, 13, 14),
(10, 15, 13, 14), (2, 2, 10, 10)]
intermediate_list = [
Dyn,
(2, 5, 6, 9),
(10, 15, 13, 14),
(10, 15, 13, 14),
(2, 2, 10, 10),
]
start_dim = [1, 2, 1, 2, 0]
end_dim = [1, 3, 3, 3, -2]
@ -795,7 +945,7 @@ class TypeCheckerTest(TestCase):
# annotate our argument
for n in graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
n.type = annotation
b = B.forward(torch.rand(input))
@ -803,7 +953,7 @@ class TypeCheckerTest(TestCase):
tc.type_check()
for n in graph.nodes:
if n.op == 'output':
if n.op == "output":
assert is_consistent(n.type, TensorType(b.size()))
@skipIfNoTorchVision
@ -825,36 +975,34 @@ class TypeCheckerTest(TestCase):
gm_run.graph.eliminate_dead_code()
# here we are checking for consistency with fully dynamic nodes
for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))
assert is_consistent(n1.type, TensorType(n2.meta["tensor_meta"].shape))
# here we give the same input as to runtime
gm_static_with_types = symbolic_trace(resnet50())
# we initialize our placeholder
for n in gm_static_with_types.graph.nodes:
if n.op == 'placeholder':
if n.op == "placeholder":
n.type = TensorType((1, 3, 224, 224))
g = GraphTypeChecker({}, gm_static_with_types)
g.type_check()
for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
assert n1.type == TensorType(n2.meta['tensor_meta'].shape)
assert n1.type == TensorType(n2.meta["tensor_meta"].shape)
# apply shape inference to graph and check
# that the batch size is equal across all layers
infer_symbolic_types(gm_static)
batch_sizes = set()
gm_static.graph.eliminate_dead_code()
for n in gm_static.graph.nodes:
assert isinstance(n.type, TensorType)
batch_sizes.add(n.type.__args__[0])
assert (len(batch_sizes) == 1)
assert len(batch_sizes) == 1
def test_type_check_batch_norm_symbolic(self):
class BasicBlock(torch.nn.Module):
def __init__(self, inplanes, planes):
super().__init__()
norm_layer = torch.nn.BatchNorm2d
@ -875,10 +1023,14 @@ class TypeCheckerTest(TestCase):
infer_symbolic_types(traced)
my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)],
TensorType[(2, 2, sympy.symbols('~7'), 4)],
TensorType[(2, 2, sympy.symbols('~7'), 4)],
TensorType[(2, 2, sympy.symbols('~7'), 4)]])
my_types = iter(
[
TensorType[(2, 2, sympy.symbols("~7"), 4)],
TensorType[(2, 2, sympy.symbols("~7"), 4)],
TensorType[(2, 2, sympy.symbols("~7"), 4)],
TensorType[(2, 2, sympy.symbols("~7"), 4)],
]
)
for n in graph.nodes:
assert n.type == next(my_types)
@ -887,6 +1039,7 @@ class TypeCheckerTest(TestCase):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
@ -901,13 +1054,14 @@ class TypeCheckerTest(TestCase):
infer_symbolic_types(symbolic_traced)
expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))),
TensorType((2, 3, 4)),
TensorType((1, 2, 3, sympy.symbols('~1'))),
TensorType((1, 2, 3, sympy.symbols('~1')))]
expected_ph_types = [
TensorType((1, 2, 3, sympy.symbols("~0"))),
TensorType((2, 3, 4)),
TensorType((1, 2, 3, sympy.symbols("~1"))),
TensorType((1, 2, 3, sympy.symbols("~1"))),
]
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
assert n.type == next(expected_iter)
@ -915,6 +1069,7 @@ class TypeCheckerTest(TestCase):
class M(torch.nn.Module):
def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
return torch.add(x, y)
module = M()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
tc = GraphTypeChecker({}, symbolic_traced)
@ -923,10 +1078,12 @@ class TypeCheckerTest(TestCase):
r = Refine(symbolic_traced)
r.refine()
expected_ph_types = [TensorType((1, 2)),
TensorType((sympy.symbols('~1'), 2)),
TensorType((sympy.symbols('~1'), 2)),
TensorType((sympy.symbols('~1'), 2))]
expected_ph_types = [
TensorType((1, 2)),
TensorType((sympy.symbols("~1"), 2)),
TensorType((sympy.symbols("~1"), 2)),
TensorType((sympy.symbols("~1"), 2)),
]
expected_iter = iter(expected_ph_types)
for n in symbolic_traced.graph.nodes:
@ -955,12 +1112,11 @@ class TypeCheckerTest(TestCase):
infer_symbolic_types(traced)
for n in traced.graph.nodes:
if n.op == 'call_module':
if n.op == "call_module":
assert isinstance(n.type.__args__[2], sympy.floor)
assert isinstance(n.type.__args__[3], sympy.floor)
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
super().__init__()
@ -971,7 +1127,7 @@ class TypeCheckerTest(TestCase):
self.fc1 = torch.nn.Linear(5, 120)
self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))
def forward(self, x : TensorType((4, 3, Dyn, Dyn))):
def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
out = self.conv1(x)
out = self.pool(out)
out = self.conv2(out)
@ -989,13 +1145,26 @@ class TypeCheckerTest(TestCase):
infer_symbolic_types(traced)
for n in traced.graph.nodes:
if n.target == 'conv1':
assert n.type == TensorType((4, 6, sympy.floor(sympy.symbols('~0') - 4),
sympy.floor(sympy.symbols('~1') - 4)))
if n.target == "conv1":
assert n.type == TensorType(
(
4,
6,
sympy.floor(sympy.symbols("~0") - 4),
sympy.floor(sympy.symbols("~1") - 4),
)
)
elif n.target == 'conv2':
assert n.type == TensorType((4, 16, sympy.floor(sympy.symbols('~4') - 4),
sympy.floor(sympy.symbols('~5') - 4)))
elif n.target == "conv2":
assert n.type == TensorType(
(
4,
16,
sympy.floor(sympy.symbols("~4") - 4),
sympy.floor(sympy.symbols("~5") - 4),
)
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -69,10 +69,12 @@ class TestMatcher(JitTestCase):
def test_subgraph_matcher_with_list(self):
def original(x, y):
return torch.ops.aten.view(x, [5, y.shape[0]])
original_graph = torch.fx.symbolic_trace(original).graph
def pattern(x, y, z):
return torch.ops.aten.view(x, [z, y.shape[0]])
pattern_graph = torch.fx.symbolic_trace(pattern).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
@ -81,11 +83,17 @@ class TestMatcher(JitTestCase):
def test_subgraph_matcher_with_list_bad(self):
def original(x, y):
return torch.ops.aten._reshape_alias_copy.default(x, [1, y.shape[0]], [y.shape[1], y.shape[1]])
return torch.ops.aten._reshape_alias_copy.default(
x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
)
original_graph = torch.fx.symbolic_trace(original).graph
def pattern(x, y, b):
return torch.ops.aten._reshape_alias_copy.default(x, [b, y.shape[0], y.shape[1]], [y.shape[1]])
return torch.ops.aten._reshape_alias_copy.default(
x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
)
pattern_graph = torch.fx.symbolic_trace(pattern).graph
subgraph_matcher = SubgraphMatcher(pattern_graph)
@ -101,6 +109,7 @@ class TestMatcher(JitTestCase):
def pattern(x):
return x + 2
pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
pattern_graph.eliminate_dead_code()
@ -116,7 +125,10 @@ class TestMatcher(JitTestCase):
inputs = (torch.randn(20, 16, 50, 32),)
def maxpool(x, kernel_size, stride, padding, dilation):
return torch.ops.aten.max_pool2d_with_indices.default(x, kernel_size, stride, padding, dilation)
return torch.ops.aten.max_pool2d_with_indices.default(
x, kernel_size, stride, padding, dilation
)
maxpool_graph = torch.fx.symbolic_trace(maxpool).graph
maxpool_matcher = SubgraphMatcher(maxpool_graph)
@ -144,7 +156,9 @@ class TestMatcher(JitTestCase):
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
def test_split_to_graph_and_name_node_map(self):
"""Testing the internal helper function for splitting the pattern graph"""
from torch.fx.passes.utils.matcher_with_name_node_map_utils import _split_to_graph_and_name_node_map
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
_split_to_graph_and_name_node_map,
)
def pattern(x, weight):
conv = F.conv2d(x, weight)
@ -153,6 +167,7 @@ class TestMatcher(JitTestCase):
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
from torch._export import capture_pre_autograd_graph
example_inputs = (
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
@ -166,8 +181,7 @@ class TestMatcher(JitTestCase):
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
def test_matcher_with_name_node_map_function(self):
"""Testing SubgraphMatcherWithNameNodeMap with function pattern
"""
"""Testing SubgraphMatcherWithNameNodeMap with function pattern"""
def target_graph(x, weight):
x = x * 2
@ -184,13 +198,16 @@ class TestMatcher(JitTestCase):
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
from torch._export import capture_pre_autograd_graph
example_inputs = (
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
)
pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
target_gm = capture_pre_autograd_graph(WrapperModule(target_graph), example_inputs)
target_gm = capture_pre_autograd_graph(
WrapperModule(target_graph), example_inputs
)
internal_matches = matcher.match(target_gm.graph)
for internal_match in internal_matches:
name_node_map = internal_match.name_node_map
@ -200,12 +217,14 @@ class TestMatcher(JitTestCase):
# check if we correctly annotated the target graph module
for n in target_gm.graph.nodes:
if n == name_node_map["conv"]:
assert "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation"
assert (
"custom_annotation" in n.meta
and n.meta["custom_annotation"] == "annotation"
)
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
def test_matcher_with_name_node_map_module(self):
"""Testing SubgraphMatcherWithNameNodeMap with module pattern
"""
"""Testing SubgraphMatcherWithNameNodeMap with module pattern"""
class M(torch.nn.Module):
def __init__(self):
@ -215,7 +234,6 @@ class TestMatcher(JitTestCase):
def forward(self, x):
return self.linear(x)
class Pattern(torch.nn.Module):
def __init__(self):
super().__init__()
@ -228,9 +246,8 @@ class TestMatcher(JitTestCase):
return linear, {"linear": linear, "x": x}
from torch._export import capture_pre_autograd_graph
example_inputs = (
torch.randn(3, 5),
)
example_inputs = (torch.randn(3, 5),)
pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs)
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
target_gm = capture_pre_autograd_graph(M(), example_inputs)
@ -243,7 +260,11 @@ class TestMatcher(JitTestCase):
# check if we correctly annotated the target graph module
for n in target_gm.graph.nodes:
if n == name_node_map["linear"]:
assert "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation"
assert (
"custom_annotation" in n.meta
and n.meta["custom_annotation"] == "annotation"
)
if __name__ == "__main__":
run_tests()

View File

@ -9,9 +9,13 @@ import torch
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected
from torch.fx.passes.utils.source_matcher_utils import (
check_subgraphs_connected,
get_source_partitions,
)
from torch.testing._internal.jit_utils import JitTestCase
class TestSourceMatcher(JitTestCase):
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_linear_relu_linear(self):
@ -33,15 +37,32 @@ class TestSourceMatcher(JitTestCase):
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
module_partitions = get_source_partitions(
gm.graph, [torch.nn.Linear, torch.nn.ReLU]
)
self.assertEqual(len(module_partitions), 2)
self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(
check_subgraphs_connected(
module_partitions[torch.nn.Linear][0],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertTrue(
check_subgraphs_connected(
module_partitions[torch.nn.Linear][1],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertFalse(
check_subgraphs_connected(
module_partitions[torch.nn.Linear][2],
module_partitions[torch.nn.ReLU][0],
)
)
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_conv_relu_maxpool(self):
@ -69,21 +90,50 @@ class TestSourceMatcher(JitTestCase):
return self.maxpool(self.relu(z))
inputs = (torch.randn(1, 3, 256, 256),)
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(*inputs)
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(
*inputs
)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
module_partitions = get_source_partitions(
gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
)
self.assertEqual(len(module_partitions), 3)
self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0]))
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0]))
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0]))
self.assertFalse(
check_subgraphs_connected(
module_partitions[torch.nn.Conv2d][0],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertFalse(
check_subgraphs_connected(
module_partitions[torch.nn.Conv2d][1],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertTrue(
check_subgraphs_connected(
module_partitions[torch.nn.Conv2d][2],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertFalse(
check_subgraphs_connected(
module_partitions[torch.nn.MaxPool2d][0],
module_partitions[torch.nn.ReLU][0],
)
)
self.assertTrue(
check_subgraphs_connected(
module_partitions[torch.nn.ReLU][0],
module_partitions[torch.nn.MaxPool2d][0],
)
)
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
def test_module_partitioner_functional_conv_relu_conv(self):
@ -96,7 +146,15 @@ class TestSourceMatcher(JitTestCase):
self.groups = 1
def forward(self, x, weight, bias):
return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(
x,
weight,
bias,
self.stride,
self.padding,
self.dilation,
self.groups,
)
class M(torch.nn.Module):
def __init__(self):
@ -114,7 +172,9 @@ class TestSourceMatcher(JitTestCase):
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
module_partitions = get_source_partitions(
gm.graph, [torch.nn.functional.conv2d]
)
self.assertEqual(len(module_partitions), 1)
self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)
@ -138,7 +198,9 @@ class TestSourceMatcher(JitTestCase):
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
gm.graph.eliminate_dead_code()
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])
module_partitions = get_source_partitions(
gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
)
self.assertEqual(len(module_partitions), 2)
self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
test/fx/test_subgraph_rewriter.py
@ -4,8 +4,9 @@ import os
import sys
import torch
from torch.fx import symbolic_trace, subgraph_rewriter
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx.annotate import annotate
# Make the helper files in test/ importable
from torch.fx.experimental.rewriter import RewritingTracer
@ -13,10 +14,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_fx.py TESTNAME\n\n"
"instead.")
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_fx.py TESTNAME\n\n"
"instead."
)
@torch.fx.wrap
def wrapped_gemm_bias_mul(a, b, bias):
@ -24,14 +28,15 @@ def wrapped_gemm_bias_mul(a, b, bias):
mul_res = lin_res * a
return lin_res, mul_res
@torch.fx.wrap
def wrapped_gemm_bias_mul_with_c(a, b, bias, c):
lin_res = torch.nn.functional.linear(a, b, bias=bias)
mul_res = lin_res * c
return lin_res, mul_res
class TestSubgraphRewriter(JitTestCase):
class TestSubgraphRewriter(JitTestCase):
def test_subgraph_rewriter_preserves_logic(self):
class M(torch.nn.Module):
def forward(self, x):
@ -110,7 +115,9 @@ class TestSubgraphRewriter(JitTestCase):
x = torch.randn(1, 5)
matches = subgraph_rewriter.replace_pattern_with_filters(traced, pattern, replacement, [])
matches = subgraph_rewriter.replace_pattern_with_filters(
traced, pattern, replacement, []
)
traced.graph.lint()
@ -297,7 +304,9 @@ class TestSubgraphRewriter(JitTestCase):
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self):
def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(
self,
):
class M(torch.nn.Module):
def forward(self, x):
y = torch.relu(x)
@ -326,7 +335,9 @@ class TestSubgraphRewriter(JitTestCase):
test_outs = traced.forward(x)
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self):
def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(
self,
):
class M(torch.nn.Module):
def forward(self, x, w1, w2, b1, b2):
m0 = torch.cat([w1, w2])
@ -385,6 +396,7 @@ class TestSubgraphRewriter(JitTestCase):
Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@ -483,7 +495,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(type(submod), torch.nn.ReLU)
def test_subgraph_rewriter_annotations_int(self):
class M1(torch.nn.Module):
def forward(self, x):
y: int = x
@ -500,12 +511,11 @@ class TestSubgraphRewriter(JitTestCase):
module = M2()
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
if n.op == 'placeholder':
if n.op == "placeholder":
assert n.type == int
assert m.type == int
def test_subgraph_rewriter_replace_consecutive_submodules(self):
def f(x):
x = torch.sigmoid(x)
x = torch.sigmoid(x)
@ -536,7 +546,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_with_overlapping_matches(self):
def f(x):
x = torch.sigmoid(x)
x = torch.sigmoid(x)
@ -569,7 +578,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_replace_with_multiple_outputs(self):
def f(x):
y = torch.sigmoid(x)
z = torch.relu(x)
@ -602,7 +610,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_replace_with_duplicated_outputs(self):
def f(x1, x2):
x = x1 - x2
y = torch.sigmoid(x)
@ -670,7 +677,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_call_method(self):
class M(torch.nn.Module):
def forward(self, x):
x = x.dequantize()
@ -701,7 +707,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(ref_outs, test_outs)
def test_subgraph_rewriter_nodes_with_kwargs(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@ -737,7 +742,6 @@ class TestSubgraphRewriter(JitTestCase):
self.assertTrue(found_repalcement_node)
def test_subgraph_rewriter_local_revert(self):
# Following model will have 3 anchors as the matching candidate with the given pattern
# Anchor 1 and 3 is a real match, but anchor 2 is not.
# The subgraph rewriter should be able to revert the changes made while matching anchor 2.
@ -763,9 +767,7 @@ class TestSubgraphRewriter(JitTestCase):
# potential match at anchor 1
mul_res_1 = in1 * lin_res_2
sum_res_1 = mul_res_1 + in1
lin_res_3 = torch.nn.functional.linear(
sum_res_1, self.w2, bias=self.b2
)
lin_res_3 = torch.nn.functional.linear(sum_res_1, self.w2, bias=self.b2)
sigmoid_res_1 = torch.sigmoid(lin_res_3)
# potential match at anchor 2
mul_res_2 = lin_res_3 * sigmoid_res_1
@ -791,9 +793,8 @@ class TestSubgraphRewriter(JitTestCase):
traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(
traced,
gemm_bias_mul_pattern_with_c,
gemm_bias_mul_replacement_with_c)
traced, gemm_bias_mul_pattern_with_c, gemm_bias_mul_replacement_with_c
)
self.assertEqual(len(matches), 2)
@ -834,7 +835,7 @@ class TestSubgraphRewriter(JitTestCase):
return x
def second_input_is_scalar(match, original_graph, pattern_graph):
""" check the node that's matched to the second input of the pattern graph
"""check the node that's matched to the second input of the pattern graph
is a scalar number
"""
input_idx = 0
@ -848,19 +849,21 @@ class TestSubgraphRewriter(JitTestCase):
return True
def check_replacement_nodes(self, traced, matches):
replacement_nodes_in_graph = [node for node in traced.graph.nodes if node.target == torch.mul]
replacement_nodes_in_graph = [
node for node in traced.graph.nodes if node.target == torch.mul
]
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
self.assertEqual(len(replacement_nodes_in_graph), len(replacement_nodes_in_res))
self.assertEqual(
len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
)
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
return len(replacement_nodes_in_graph)
# match without filter, should find 2 match
traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern_with_filters(
traced,
BinaryOpScalarReLUPattern,
BinaryOpScalarReLUReplacement,
None)
traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, None
)
self.assertEqual(len(matches), 2)
self.assertEqual(check_replacement_nodes(self, traced, matches), 2)
@ -870,7 +873,8 @@ class TestSubgraphRewriter(JitTestCase):
traced,
BinaryOpScalarReLUPattern,
BinaryOpScalarReLUReplacement,
[second_input_is_scalar])
[second_input_is_scalar],
)
self.assertEqual(len(matches), 1)
self.assertEqual(check_replacement_nodes(self, traced, matches), 1)
@ -890,10 +894,13 @@ class TestSubgraphRewriter(JitTestCase):
self.assertEqual(len(matches), 1)
self.assertExpectedInline(traced.code.strip(), """\
self.assertExpectedInline(
traced.code.strip(),
"""\
def forward(self, x):
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None
return _reshape_alias_copy_default_1""") # noqa: B950
return _reshape_alias_copy_default_1""",
) # noqa: B950
def test_replacement_with_attrs(self):
class M(torch.nn.Module):
@ -928,11 +935,15 @@ def forward(self, x):
def test_matching_variable_arguments(self):
class M(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.max_pool2d_with_indices.default(x, [2, 2], stride=[2, 2])
return torch.ops.aten.max_pool2d_with_indices.default(
x, [2, 2], stride=[2, 2]
)
def pattern(x, kernel_size, stride):
# default padding is [0, 0]
return torch.ops.aten.max_pool2d_with_indices.default(x, kernel_size, stride, padding=[0, 0])
return torch.ops.aten.max_pool2d_with_indices.default(
x, kernel_size, stride, padding=[0, 0]
)
traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern(traced, pattern, pattern)
@ -951,12 +962,20 @@ def forward(self, x):
return torch.sub(torch.mul(x, y), y)
traced = symbolic_trace(M())
matches = subgraph_rewriter.replace_pattern_with_filters(traced, pattern, replacement)
matches = subgraph_rewriter.replace_pattern_with_filters(
traced, pattern, replacement
)
def check_replacement_nodes(self, traced, matches):
replacement_nodes_in_graph = [node for node in traced.graph.nodes if node.target in {torch.sub, torch.mul}]
replacement_nodes_in_graph = [
node
for node in traced.graph.nodes
if node.target in {torch.sub, torch.mul}
]
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
self.assertEqual(len(replacement_nodes_in_graph), len(replacement_nodes_in_res))
self.assertEqual(
len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
)
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
return len(replacement_nodes_in_graph)
File diff suppressed because it is too large