mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable UFMT on all of test/fx
(#123622)
Partially addresses #123062 Ran lintrunner on: - `test/fx` with command: ```bash lintrunner -a --take UFMT --all-files ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/123622 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
3b3962f7b3
commit
c96bd3de06
@ -1178,21 +1178,6 @@ exclude_patterns = [
|
||||
'test/functorch/test_vmap.py',
|
||||
'test/functorch/test_vmap_registrations.py',
|
||||
'test/functorch/xfail_suggester.py',
|
||||
'test/fx/named_tup.py',
|
||||
'test/fx/quantization.py',
|
||||
'test/fx/test_common_passes.py',
|
||||
'test/fx/test_cse_pass.py',
|
||||
'test/fx/test_dce_pass.py',
|
||||
'test/fx/test_future.py',
|
||||
'test/fx/test_fx_const_fold.py',
|
||||
'test/fx/test_fx_param_shape_control_flow.py',
|
||||
'test/fx/test_gradual_type.py',
|
||||
'test/fx/test_matcher_utils.py',
|
||||
'test/fx/test_pass_infra.py',
|
||||
'test/fx/test_source_matcher_utils.py',
|
||||
'test/fx/test_subgraph_rewriter.py',
|
||||
'test/fx/test_z3_gradual_types.py',
|
||||
'test/fx/test_fx_split.py',
|
||||
'test/jit/__init__.py',
|
||||
'test/jit/_imported_class_test/__init__.py',
|
||||
'test/jit/_imported_class_test/bar.py',
|
||||
|
@ -2,6 +2,7 @@ from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MyNamedTup(NamedTuple):
|
||||
i : torch.Tensor
|
||||
f : torch.Tensor
|
||||
i: torch.Tensor
|
||||
f: torch.Tensor
|
||||
|
@ -1,20 +1,24 @@
|
||||
r'''
|
||||
r"""
|
||||
**This file is EXPERIMENTAL and is mostly used for testing purposes! Do not
|
||||
rely on it for anything!**
|
||||
'''
|
||||
"""
|
||||
import operator
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
from torch.fx.graph import map_arg
|
||||
from torch.fx.proxy import Proxy
|
||||
import sys
|
||||
import torch
|
||||
from torch.nn.utils import fuse_conv_bn_weights
|
||||
import operator
|
||||
from typing import Optional
|
||||
|
||||
# can be a
|
||||
# module type, a builtin function, or a string to match target
|
||||
|
||||
def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps):
|
||||
|
||||
def _minmax_scale_zeropoint(
|
||||
min_val, max_val, qmin=-127, qmax=128, eps=torch.finfo(torch.float32).eps
|
||||
):
|
||||
min_val = min(0.0, min_val)
|
||||
max_val = max(0.0, max_val)
|
||||
if max_val == min_val:
|
||||
@ -28,9 +32,10 @@ def _minmax_scale_zeropoint(min_val, max_val, qmin=-127, qmax=128, eps=torch.fin
|
||||
zero_point = int(zero_point)
|
||||
return scale, zero_point
|
||||
|
||||
|
||||
class MinMaxObserver:
|
||||
def __init__(self, quantizer, node):
|
||||
self.min, self.max = float('inf'), float('-inf')
|
||||
self.min, self.max = float("inf"), float("-inf")
|
||||
self.all_tensors = True
|
||||
|
||||
def observe(self, node, env):
|
||||
@ -44,6 +49,7 @@ class MinMaxObserver:
|
||||
def scale_zeropoint(self):
|
||||
return _minmax_scale_zeropoint(self.min, self.max, qmin=0, qmax=255)
|
||||
|
||||
|
||||
class NoObserver:
|
||||
def __init__(self, quantizer, node):
|
||||
pass
|
||||
@ -51,11 +57,15 @@ class NoObserver:
|
||||
def observe(self, node, env):
|
||||
pass
|
||||
|
||||
|
||||
_DEFAULT_QUANTIZATION_PATTERNS = {}
|
||||
|
||||
|
||||
def register_pattern(pattern):
|
||||
def insert(fn):
|
||||
_DEFAULT_QUANTIZATION_PATTERNS[pattern] = fn
|
||||
return fn
|
||||
|
||||
return insert
|
||||
|
||||
|
||||
@ -66,12 +76,19 @@ class Add(MinMaxObserver):
|
||||
return NotImplemented
|
||||
scale, zeropoint = self.scale_zeropoint()
|
||||
return quantizer.quantized_graph.create_node(
|
||||
'call_function', torch.ops.quantized.add, load_arg(node.args), {'scale': scale, 'zero_point': zeropoint})
|
||||
"call_function",
|
||||
torch.ops.quantized.add,
|
||||
load_arg(node.args),
|
||||
{"scale": scale, "zero_point": zeropoint},
|
||||
)
|
||||
|
||||
|
||||
class Relu(NoObserver):
|
||||
def quantize(self, quantizer, node, load_arg):
|
||||
return torch.relu(load_arg(node.args[0])) # torch.relu works directly on quantized tensors?
|
||||
return torch.relu(
|
||||
load_arg(node.args[0])
|
||||
) # torch.relu works directly on quantized tensors?
|
||||
|
||||
|
||||
# these ops have quantized equivalents that do not need any extra information
|
||||
@register_pattern(torch.nn.ReLU)
|
||||
@ -82,15 +99,24 @@ class CopyNode(NoObserver):
|
||||
def quantize(self, quantizer, node, load_arg):
|
||||
return quantizer.quantized_graph.node_copy(node, load_arg)
|
||||
|
||||
|
||||
class IdentityModule(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x
|
||||
|
||||
|
||||
# handle conv, maybe followed by bn, maybe followed by relu
|
||||
@register_pattern(torch.nn.modules.conv.Conv2d)
|
||||
@register_pattern((torch.nn.ReLU, torch.nn.modules.conv.Conv2d))
|
||||
@register_pattern((torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d))
|
||||
@register_pattern((torch.nn.ReLU, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)))
|
||||
@register_pattern(
|
||||
(torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d)
|
||||
)
|
||||
@register_pattern(
|
||||
(
|
||||
torch.nn.ReLU,
|
||||
(torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.conv.Conv2d),
|
||||
)
|
||||
)
|
||||
class ConvNormRelu(MinMaxObserver):
|
||||
def __init__(self, quantizer, node):
|
||||
super().__init__(quantizer, node)
|
||||
@ -112,21 +138,41 @@ class ConvNormRelu(MinMaxObserver):
|
||||
|
||||
if self.bn_node is not None:
|
||||
weight, bias = fuse_conv_bn_weights(
|
||||
weight, bias, self.bn.running_mean, self.bn.running_var,
|
||||
self.bn.eps, self.bn.weight, self.bn.bias)
|
||||
weight,
|
||||
bias,
|
||||
self.bn.running_mean,
|
||||
self.bn.running_var,
|
||||
self.bn.eps,
|
||||
self.bn.weight,
|
||||
self.bn.bias,
|
||||
)
|
||||
|
||||
min_val, max_val = float(weight.min()), float(weight.max())
|
||||
|
||||
act_scale, act_zp = self.scale_zeropoint()
|
||||
|
||||
weight_scale, weight_zp = _minmax_scale_zeropoint(min_val, max_val)
|
||||
qweight = torch.quantize_per_tensor(weight, weight_scale, weight_zp, torch.qint8)
|
||||
qweight = torch.quantize_per_tensor(
|
||||
weight, weight_scale, weight_zp, torch.qint8
|
||||
)
|
||||
|
||||
ctor = torch.ao.nn.intrinsic.quantized.ConvReLU2d if self.relu_node is not None else torch.ao.nn.quantized.Conv2d
|
||||
ctor = (
|
||||
torch.ao.nn.intrinsic.quantized.ConvReLU2d
|
||||
if self.relu_node is not None
|
||||
else torch.ao.nn.quantized.Conv2d
|
||||
)
|
||||
|
||||
qconv = ctor(mod.in_channels, mod.out_channels, mod.kernel_size,
|
||||
mod.stride, mod.padding, mod.dilation, mod.groups,
|
||||
mod.bias is not None, mod.padding_mode)
|
||||
qconv = ctor(
|
||||
mod.in_channels,
|
||||
mod.out_channels,
|
||||
mod.kernel_size,
|
||||
mod.stride,
|
||||
mod.padding,
|
||||
mod.dilation,
|
||||
mod.groups,
|
||||
mod.bias is not None,
|
||||
mod.padding_mode,
|
||||
)
|
||||
|
||||
qconv.set_weight_bias(qweight, bias)
|
||||
qconv.scale = float(act_scale)
|
||||
@ -139,24 +185,31 @@ class ConvNormRelu(MinMaxObserver):
|
||||
# try to call it, so replace with something that does nothing.
|
||||
setattr(quantizer.modules[parent_name], bn_name, IdentityModule())
|
||||
|
||||
return quantizer.quantized_graph.create_node('call_module', self.conv_node.target, (load_arg(self.conv_node.args[0]),), {})
|
||||
return quantizer.quantized_graph.create_node(
|
||||
"call_module",
|
||||
self.conv_node.target,
|
||||
(load_arg(self.conv_node.args[0]),),
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
# turn foo.bar -> ['foo', 'bar']
|
||||
def _parent_name(target):
|
||||
r = target.rsplit('.', 1)
|
||||
r = target.rsplit(".", 1)
|
||||
if len(r) == 1:
|
||||
return '', r[0]
|
||||
return "", r[0]
|
||||
else:
|
||||
return r[0], r[1]
|
||||
|
||||
|
||||
|
||||
class DefaultQuant(MinMaxObserver):
|
||||
def quantize(self, input):
|
||||
assert self.all_tensors
|
||||
scale, zeropoint = self.scale_zeropoint()
|
||||
return torch.quantize_per_tensor(Proxy(input), scale, zeropoint, torch.quint8).node
|
||||
return torch.quantize_per_tensor(
|
||||
Proxy(input), scale, zeropoint, torch.quint8
|
||||
).node
|
||||
|
||||
|
||||
def matches(modules, node, pattern, max_uses=sys.maxsize):
|
||||
if isinstance(pattern, tuple):
|
||||
@ -169,12 +222,12 @@ def matches(modules, node, pattern, max_uses=sys.maxsize):
|
||||
return False
|
||||
|
||||
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
|
||||
if node.op != 'call_module':
|
||||
if node.op != "call_module":
|
||||
return False
|
||||
if not isinstance(modules[node.target], self_match):
|
||||
return False
|
||||
elif callable(self_match):
|
||||
if node.op != 'call_function' or node.target is not self_match:
|
||||
if node.op != "call_function" or node.target is not self_match:
|
||||
return False
|
||||
elif node.target != self_match:
|
||||
return False
|
||||
@ -185,11 +238,16 @@ def matches(modules, node, pattern, max_uses=sys.maxsize):
|
||||
if len(arg_matches) != len(node.args):
|
||||
return False
|
||||
|
||||
return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
|
||||
return all(
|
||||
matches(modules, node, arg_match, max_uses=1)
|
||||
for node, arg_match in zip(node.args, arg_matches)
|
||||
)
|
||||
|
||||
|
||||
class Quantizer:
|
||||
def __init__(self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant):
|
||||
def __init__(
|
||||
self, mod, patterns=_DEFAULT_QUANTIZATION_PATTERNS, quant_ctor=DefaultQuant
|
||||
):
|
||||
self.root = mod
|
||||
self.graph = mod.graph
|
||||
self.quant_ctor = quant_ctor
|
||||
@ -205,8 +263,6 @@ class Quantizer:
|
||||
# initialize an quant_ctor object for each
|
||||
self.quants = self._find_quants(quant_ctor)
|
||||
|
||||
|
||||
|
||||
def observe(self, args):
|
||||
# most of this function is just an interpreter for the graph
|
||||
# it would be possible to put this in some abstraction, but
|
||||
@ -220,21 +276,23 @@ class Quantizer:
|
||||
def load_arg(a):
|
||||
return map_arg(a, lambda node: env[node.name])
|
||||
|
||||
output_node : Optional[Node] = None
|
||||
output_node: Optional[Node] = None
|
||||
for node in self.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
if node.op == "placeholder":
|
||||
result = next(args_iter)
|
||||
elif node.op == 'get_attr':
|
||||
elif node.op == "get_attr":
|
||||
result = self.state_dict[node.target]
|
||||
elif node.op == 'call_function':
|
||||
elif node.op == "call_function":
|
||||
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))
|
||||
elif node.op == 'call_method':
|
||||
elif node.op == "call_method":
|
||||
self_obj, *args = load_arg(node.args)
|
||||
kwargs = load_arg(node.kwargs)
|
||||
result = getattr(self_obj, node.target)(*args, **kwargs)
|
||||
elif node.op == 'call_module':
|
||||
result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs))
|
||||
elif node.op == 'output':
|
||||
elif node.op == "call_module":
|
||||
result = self.modules[node.target](
|
||||
*load_arg(node.args), **load_arg(node.kwargs)
|
||||
)
|
||||
elif node.op == "output":
|
||||
return load_arg(node.args[0])
|
||||
|
||||
env[node.name] = result
|
||||
@ -244,7 +302,7 @@ class Quantizer:
|
||||
if node.name in self.quants:
|
||||
self.quants[node.name].observe(node, env)
|
||||
|
||||
raise RuntimeError('Graph had no output node!')
|
||||
raise RuntimeError("Graph had no output node!")
|
||||
|
||||
def quantize(self):
|
||||
self.quantized_graph = Graph()
|
||||
@ -268,17 +326,26 @@ class Quantizer:
|
||||
return load_arg(n, quantized=False)
|
||||
else:
|
||||
return copy_recursive(n)
|
||||
r = env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
|
||||
|
||||
r = env[node.name] = self.quantized_graph.node_copy(
|
||||
node, lambda n: load_arg(n, quantized=False)
|
||||
)
|
||||
return r
|
||||
|
||||
for node in self.graph.nodes:
|
||||
root_node, obj = self.matches.get(node.name, (None, None))
|
||||
if root_node is None:
|
||||
# not quantized just copy it
|
||||
env[node.name] = self.quantized_graph.node_copy(node, lambda n: load_arg(n, quantized=False))
|
||||
env[node.name] = self.quantized_graph.node_copy(
|
||||
node, lambda n: load_arg(n, quantized=False)
|
||||
)
|
||||
|
||||
elif root_node is node:
|
||||
r = obj.quantize(self, node, lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)))
|
||||
r = obj.quantize(
|
||||
self,
|
||||
node,
|
||||
lambda a: map_arg(a, lambda n: load_arg(n, quantized=True)),
|
||||
)
|
||||
if r is NotImplemented:
|
||||
# quantizer choose to to quantize the node take the entire match, and just copy it over
|
||||
env[node.name] = copy_recursive(node)
|
||||
@ -318,6 +385,7 @@ class Quantizer:
|
||||
# say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
|
||||
if n.name not in quants:
|
||||
quants[n.name] = quant_ctor(self, n)
|
||||
|
||||
for node in self.graph.nodes:
|
||||
if node.name in self.matches:
|
||||
map_arg(node.args, visit_arg)
|
||||
|
@ -1,14 +1,19 @@
|
||||
# Owner(s): ["oncall: fx"]
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from torch.fx.passes.dialect.common.cse_pass import CSEPass
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, parametrize, instantiate_parametrized_tests, run_tests)
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.passes.dialect.common.cse_pass import CSEPass
|
||||
from torch.fx.graph_module import GraphModule
|
||||
instantiate_parametrized_tests,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
import itertools
|
||||
|
||||
def FactoryFunctionCall(x, device):
|
||||
y = torch.full(x.shape, 3, device=device)
|
||||
@ -62,12 +67,14 @@ def MutationMetadata(x):
|
||||
|
||||
|
||||
Passes = [CSEPass]
|
||||
Test_Cases = [TakeList,
|
||||
ReturnList,
|
||||
Mutation,
|
||||
MutationInput,
|
||||
MutationMetadata,
|
||||
MutationTorchTensorCall]
|
||||
Test_Cases = [
|
||||
TakeList,
|
||||
ReturnList,
|
||||
Mutation,
|
||||
MutationInput,
|
||||
MutationMetadata,
|
||||
MutationTorchTensorCall,
|
||||
]
|
||||
Factory_Test_Cases = [FactoryFunctionCall, MutationFactory]
|
||||
Devices = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
@ -76,12 +83,14 @@ if torch.cuda.is_available():
|
||||
|
||||
def name_fn(common_pass, f, device):
|
||||
"""Names parameterized test cases."""
|
||||
return f'{type(common_pass()).__name__}_{f.__name__}_{device}'
|
||||
return f"{type(common_pass()).__name__}_{f.__name__}_{device}"
|
||||
|
||||
|
||||
@instantiate_parametrized_tests
|
||||
class TestCommonPass(TestCase):
|
||||
|
||||
@parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn)
|
||||
@parametrize(
|
||||
"common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn
|
||||
)
|
||||
def test_correctness(self, common_pass, f, device):
|
||||
inp = torch.randn(10, device=device)
|
||||
|
||||
@ -98,8 +107,11 @@ class TestCommonPass(TestCase):
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
@parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn)
|
||||
@parametrize(
|
||||
"common_pass,f,device",
|
||||
itertools.product(Passes, Factory_Test_Cases, Devices),
|
||||
name_fn,
|
||||
)
|
||||
def test_correctness_factory(self, common_pass, f, device):
|
||||
inp = torch.randn(10, device=device)
|
||||
traced_m = make_fx(f)(inp, device)
|
||||
@ -116,5 +128,5 @@ class TestCommonPass(TestCase):
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -1,19 +1,19 @@
|
||||
# Owner(s): ["oncall: fx"]
|
||||
|
||||
import torch
|
||||
import random
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase, run_tests)
|
||||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch.fx.passes.dialect.common.cse_pass import CSEPass, get_CSE_banned_ops
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
import random
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
|
||||
|
||||
banned_ops = get_CSE_banned_ops()
|
||||
P_default = CSEPass(banned_ops=banned_ops)
|
||||
|
||||
|
||||
def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
|
||||
"""
|
||||
check if the CSE modified graph of ``f``
|
||||
@ -47,34 +47,50 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None):
|
||||
old_num_nodes = len(fx_g.graph.nodes)
|
||||
new_num_nodes = len(new_graph.nodes)
|
||||
|
||||
assert (new_num_nodes < old_num_nodes) == modified, "modified should be True if the number of nodes decrease"
|
||||
assert (
|
||||
new_num_nodes < old_num_nodes
|
||||
) == modified, "modified should be True if the number of nodes decrease"
|
||||
|
||||
if delta == -1:
|
||||
self.assertTrue(old_num_nodes >= new_num_nodes, (
|
||||
f"number of nodes increased {old_num_nodes}, {new_num_nodes}"))
|
||||
self.assertTrue(
|
||||
old_num_nodes >= new_num_nodes,
|
||||
(f"number of nodes increased {old_num_nodes}, {new_num_nodes}"),
|
||||
)
|
||||
else:
|
||||
self.assertTrue(old_num_nodes == new_num_nodes + delta, (
|
||||
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"))
|
||||
self.assertTrue(
|
||||
old_num_nodes == new_num_nodes + delta,
|
||||
(
|
||||
f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}"
|
||||
),
|
||||
)
|
||||
|
||||
# a second pass should not reduce more nodes
|
||||
res = P(new_g)
|
||||
pass_2_graph = res.graph_module.graph
|
||||
pass_2_num_nodes = len(pass_2_graph.nodes)
|
||||
self.assertTrue(pass_2_num_nodes == new_num_nodes, (
|
||||
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"))
|
||||
self.assertTrue(
|
||||
pass_2_num_nodes == new_num_nodes,
|
||||
(
|
||||
f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}"
|
||||
),
|
||||
)
|
||||
|
||||
# check correctness
|
||||
if check_val:
|
||||
true_result = fx_g(t)
|
||||
our_result = new_g(t)
|
||||
if true_result is None: # both return None
|
||||
self.assertTrue(our_result is None, f"true result is None, CSE result is {our_result}")
|
||||
self.assertTrue(
|
||||
our_result is None, f"true result is None, CSE result is {our_result}"
|
||||
)
|
||||
else: # results returned are the same
|
||||
self.assertTrue(torch.all(true_result == our_result), (
|
||||
f"results are different {true_result}, {our_result}")) # check results are the same
|
||||
self.assertTrue(
|
||||
torch.all(true_result == our_result),
|
||||
(f"results are different {true_result}, {our_result}"),
|
||||
) # check results are the same
|
||||
|
||||
|
||||
class TestCSEPass(TestCase):
|
||||
|
||||
def test_nochange(self):
|
||||
def f(x):
|
||||
a = x + 1
|
||||
@ -82,16 +98,17 @@ class TestCSEPass(TestCase):
|
||||
a = x
|
||||
d = x + a
|
||||
return b + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 0)
|
||||
|
||||
def test_empty(self):
|
||||
def f(x):
|
||||
pass
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 0)
|
||||
|
||||
|
||||
def test_immutable_list_type(self):
|
||||
def f(x):
|
||||
a = x.sum(dim=1)
|
||||
@ -99,6 +116,7 @@ class TestCSEPass(TestCase):
|
||||
c = x.sum()
|
||||
d = x.sum()
|
||||
return a + b + c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 2)
|
||||
|
||||
@ -109,6 +127,7 @@ class TestCSEPass(TestCase):
|
||||
c = x.sum(dim=1)
|
||||
d = x.sum(dim=1)
|
||||
return a + b + c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 2)
|
||||
|
||||
@ -119,6 +138,7 @@ class TestCSEPass(TestCase):
|
||||
c = a + a
|
||||
d = b + b
|
||||
return c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 2)
|
||||
|
||||
@ -129,6 +149,7 @@ class TestCSEPass(TestCase):
|
||||
c = a + a
|
||||
d = b + b
|
||||
return c + d
|
||||
|
||||
t = torch.randn(1)
|
||||
check(self, f, t, 3)
|
||||
|
||||
@ -139,6 +160,7 @@ class TestCSEPass(TestCase):
|
||||
c = x.sum(dim=1, keepdim=False)
|
||||
d = x.sum(dim=1)
|
||||
return a + b + c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 3)
|
||||
|
||||
@ -149,6 +171,7 @@ class TestCSEPass(TestCase):
|
||||
c = x.sum(dim=1, keepdim=True)
|
||||
d = x.sum(dim=1)
|
||||
return a + b + c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 2)
|
||||
|
||||
@ -159,6 +182,7 @@ class TestCSEPass(TestCase):
|
||||
c = x.sum()
|
||||
d = x.sum()
|
||||
return a + b + c + d
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 3)
|
||||
|
||||
@ -167,6 +191,7 @@ class TestCSEPass(TestCase):
|
||||
a = torch.cat((x, x))
|
||||
b = torch.cat((x, x))
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 1)
|
||||
|
||||
@ -175,12 +200,14 @@ class TestCSEPass(TestCase):
|
||||
a = torch.ones_like(x)
|
||||
b = torch.ones_like(x)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 1)
|
||||
|
||||
"""
|
||||
Generate function with random ops and check if the result is the same
|
||||
"""
|
||||
|
||||
def test_random(self):
|
||||
def f(x):
|
||||
vals = [x]
|
||||
@ -201,6 +228,7 @@ class TestCSEPass(TestCase):
|
||||
"""
|
||||
Test that banned list ban ops as expected.
|
||||
"""
|
||||
|
||||
def test_banned_list(self):
|
||||
def f(x):
|
||||
a = x + 1
|
||||
@ -217,6 +245,7 @@ class TestCSEPass(TestCase):
|
||||
a = torch.rand_like(x)
|
||||
b = torch.rand_like(x)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 0, check_val=False)
|
||||
|
||||
@ -225,9 +254,10 @@ class TestCSEPass(TestCase):
|
||||
a = torch.randn(4)
|
||||
b = torch.randn(4)
|
||||
return a + b
|
||||
|
||||
t = torch.randn(2, 2)
|
||||
check(self, f, t, 0, check_val=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
from typing import Set, Type
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
|
@ -1,34 +1,42 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
from __future__ import annotations # type: ignore[attr-defined]
|
||||
import torch
|
||||
from __future__ import annotations # type: ignore[attr-defined]
|
||||
|
||||
import typing
|
||||
|
||||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
|
||||
|
||||
class A:
|
||||
def __call__(self, x: torch.Tensor):
|
||||
return torch.add(x, x)
|
||||
|
||||
|
||||
# No forward references
|
||||
class M1(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
|
||||
return a(x)
|
||||
|
||||
|
||||
# Forward references
|
||||
class M2(torch.nn.Module):
|
||||
def forward(self, x: torch.Tensor, a: A) -> torch.Tensor:
|
||||
return a(x)
|
||||
|
||||
|
||||
# Non-torch annotation with no internal forward references
|
||||
class M3(torch.nn.Module):
|
||||
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
|
||||
return a(x[0])
|
||||
|
||||
|
||||
# Non-torch annotation with internal forward references
|
||||
class M4(torch.nn.Module):
|
||||
def forward(self, x: typing.List[torch.Tensor], a: A) -> torch.Tensor:
|
||||
return a(x[0])
|
||||
|
||||
|
||||
x = torch.rand(2, 3)
|
||||
|
||||
ref = torch.add(x, x)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
@ -21,6 +22,7 @@ class MyModuleBase(torch.nn.Module):
|
||||
def no_relu(self):
|
||||
raise Exception("not implemented")
|
||||
|
||||
|
||||
class MyModuleParamShape(MyModuleBase):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@ -72,7 +74,6 @@ class MyModuleParamNumEl(MyModuleBase):
|
||||
return self.param.numel() < 10 * 3
|
||||
|
||||
|
||||
|
||||
class MyModuleParamNElement(MyModuleBase):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@ -82,43 +83,49 @@ class MyModuleParamNElement(MyModuleBase):
|
||||
return self.param.nelement() < 10 * 3
|
||||
|
||||
|
||||
|
||||
class TestConstParamShapeInControlFlow(TestCase):
|
||||
|
||||
def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
|
||||
"""
|
||||
Verify one module only does a mm op while the other
|
||||
performs both mm and relu ops in cascade
|
||||
"""
|
||||
x = torch.randn(10, 5)
|
||||
torch.testing.assert_close(mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
|
||||
torch.testing.assert_close(
|
||||
mm_only_mod(x), torch.mm(x, mm_only_mod.get_mul_matrix())
|
||||
)
|
||||
tracer = torch.fx.Tracer(param_shapes_constant=True)
|
||||
traced_graph = tracer.trace(mm_only_mod)
|
||||
|
||||
# verify the graph module calculates the same result
|
||||
graph_mod_mm = torch.fx.GraphModule(mm_only_mod, traced_graph)
|
||||
torch.testing.assert_close(graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix()))
|
||||
|
||||
torch.testing.assert_close(
|
||||
graph_mod_mm(x), torch.mm(x, mm_only_mod.get_mul_matrix())
|
||||
)
|
||||
|
||||
# Make a new module with different parameter shape to go down the different
|
||||
# code path
|
||||
x = torch.randn(10, 15)
|
||||
torch.testing.assert_close(relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
|
||||
torch.testing.assert_close(
|
||||
relu_mod(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
|
||||
)
|
||||
|
||||
tracer2 = torch.fx.Tracer(param_shapes_constant=True)
|
||||
traced_graph2 = tracer2.trace(relu_mod)
|
||||
|
||||
# verify the graph module calculates the same result
|
||||
graph_mod_relu = torch.fx.GraphModule(relu_mod, traced_graph2)
|
||||
torch.testing.assert_close(graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix())))
|
||||
|
||||
torch.testing.assert_close(
|
||||
graph_mod_relu(x), torch.relu(torch.mm(x, relu_mod.get_mul_matrix()))
|
||||
)
|
||||
|
||||
graph1_node_targets = [n.target for n in traced_graph.nodes]
|
||||
graph2_node_targets = [n.target for n in traced_graph2.nodes]
|
||||
|
||||
# the second graph has an exta relu function call node
|
||||
assert torch.mm in graph1_node_targets and torch.mm in graph2_node_targets
|
||||
assert torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
|
||||
assert (
|
||||
torch.relu not in graph1_node_targets and torch.relu in graph2_node_targets
|
||||
)
|
||||
|
||||
def test_param_shape_const(self):
|
||||
mymod = MyModuleParamShape(in_channels=5)
|
||||
@ -151,5 +158,5 @@ class TestConstParamShapeInControlFlow(TestCase):
|
||||
self.verify_mm_relu_mods(mymod, mymod2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import List, Tuple, Dict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx.passes.split_utils import split_by_tags
|
||||
@ -40,6 +40,7 @@ class TestFXSplit(TestCase):
|
||||
self.assertIn("name", node.meta)
|
||||
self.assertEqual(node.meta["name"], node.name)
|
||||
|
||||
|
||||
class TestSplitByTags(TestCase):
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
@ -151,6 +152,7 @@ class TestSplitByTags(TestCase):
|
||||
f"{orig_to_split_fqn_mapping=}",
|
||||
)
|
||||
|
||||
|
||||
class TestSplitOutputType(TestCase):
|
||||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -1,18 +1,22 @@
|
||||
# Owner(s): ["module: fx"]
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
from torch.fx import symbolic_trace
|
||||
from torch.fx.experimental.unify_refinements import infer_symbolic_types
|
||||
from torch.fx.experimental.refinement_types import Equality
|
||||
from torch.fx.tensor_type import TensorType, Dyn, is_consistent, is_more_precise
|
||||
from torch.fx.annotate import annotate
|
||||
from torch.fx.experimental.graph_gradual_typechecker import GraphTypeChecker, broadcast_types, Refine
|
||||
from torch.fx.experimental.rewriter import RewritingTracer
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.passes.shape_prop import ShapeProp
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
import sympy
|
||||
import torch
|
||||
from torch.fx import GraphModule, symbolic_trace
|
||||
from torch.fx.annotate import annotate
|
||||
from torch.fx.experimental.graph_gradual_typechecker import (
|
||||
broadcast_types,
|
||||
GraphTypeChecker,
|
||||
Refine,
|
||||
)
|
||||
from torch.fx.experimental.refinement_types import Equality
|
||||
from torch.fx.experimental.rewriter import RewritingTracer
|
||||
from torch.fx.experimental.unify_refinements import infer_symbolic_types
|
||||
from torch.fx.passes.shape_prop import ShapeProp
|
||||
from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType
|
||||
from torch.testing._internal.common_utils import TestCase
|
||||
|
||||
|
||||
try:
|
||||
@ -23,24 +27,33 @@ except ImportError:
|
||||
HAS_TORCHVISION = False
|
||||
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
return torch.nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
|
||||
class AnnotationsTest(TestCase):
|
||||
|
||||
def test_annotations(self):
|
||||
"""
|
||||
Test type annotations in the forward function.
|
||||
The annotation should appear in the n.graph
|
||||
where n is the corresponding node in the resulting graph.
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self,
|
||||
x: TensorType((1, 2, 3, Dyn)),
|
||||
y: Dyn,
|
||||
z: TensorType[Dyn, 3, Dyn]):
|
||||
def forward(
|
||||
self, x: TensorType((1, 2, 3, Dyn)), y: Dyn, z: TensorType[Dyn, 3, Dyn]
|
||||
):
|
||||
return torch.add(x, y) + z
|
||||
|
||||
module = M()
|
||||
@ -50,20 +63,19 @@ class AnnotationsTest(TestCase):
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == next(expected_iter)
|
||||
|
||||
def test_annotate(self):
|
||||
class M(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
y = annotate(x, TensorType((1, 2, 3, Dyn)))
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == TensorType((1, 2, 3, Dyn))
|
||||
|
||||
def test_consistency(self):
|
||||
@ -83,7 +95,9 @@ class AnnotationsTest(TestCase):
|
||||
self.assertTrue(is_more_precise(TensorType((1, 2, 3)), TensorType((1, Dyn, 3))))
|
||||
self.assertTrue(is_more_precise(int, Dyn))
|
||||
self.assertTrue(is_more_precise(int, int))
|
||||
self.assertFalse(is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5))))
|
||||
self.assertFalse(
|
||||
is_more_precise(TensorType((1, 2, 3)), TensorType((1, 2, 3, 5)))
|
||||
)
|
||||
self.assertFalse(is_more_precise(TensorType((1, 2, 3)), int))
|
||||
|
||||
def test_broadcasting1(self):
|
||||
@ -94,7 +108,10 @@ class AnnotationsTest(TestCase):
|
||||
t5 = TensorType((4, 4, 4))
|
||||
# todo switch all code to use list instead of tuple
|
||||
t6 = TensorType([1])
|
||||
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4)))
|
||||
assert broadcast_types(t1, t2) == (
|
||||
TensorType((1, 2, 3, 4)),
|
||||
TensorType((1, 2, 3, 4)),
|
||||
)
|
||||
assert broadcast_types(t3, t4) == (t4, t4)
|
||||
assert broadcast_types(t5, t6) == (t5, t5)
|
||||
|
||||
@ -102,46 +119,58 @@ class AnnotationsTest(TestCase):
|
||||
t1 = TensorType((2, 3, 4))
|
||||
t2 = TensorType((1, 2, 1, 4))
|
||||
|
||||
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, 4)), TensorType((1, 2, 3, 4)))
|
||||
assert broadcast_types(t1, t2) == (
|
||||
TensorType((1, 2, 3, 4)),
|
||||
TensorType((1, 2, 3, 4)),
|
||||
)
|
||||
|
||||
def test_broadcasting3(self):
|
||||
t1 = TensorType((1, 2, 3, Dyn))
|
||||
t2 = TensorType((2, 3, 4))
|
||||
assert broadcast_types(t1, t2) == (TensorType((1, 2, 3, Dyn)), TensorType((1, 2, 3, 4)))
|
||||
assert broadcast_types(t1, t2) == (
|
||||
TensorType((1, 2, 3, Dyn)),
|
||||
TensorType((1, 2, 3, 4)),
|
||||
)
|
||||
|
||||
|
||||
class TypeCheckerTest(TestCase):
|
||||
|
||||
def test_type_check_add_with_broadcast(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
tc.type_check()
|
||||
expected_ph_types = [TensorType((1, 2, 3, Dyn)),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((1, 2, 3, Dyn)),
|
||||
TensorType((1, 2, 3, Dyn))]
|
||||
expected_ph_types = [
|
||||
TensorType((1, 2, 3, Dyn)),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((1, 2, 3, Dyn)),
|
||||
TensorType((1, 2, 3, Dyn)),
|
||||
]
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'call_function':
|
||||
assert n.meta['broadcast']
|
||||
if n.op == "call_function":
|
||||
assert n.meta["broadcast"]
|
||||
assert n.type == next(expected_iter)
|
||||
|
||||
def test_type_check_add_with_scalar(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: int, y: TensorType((2, 3, 4))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
tc.type_check()
|
||||
expected_ph_types = [int,
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((2, 3, 4))]
|
||||
expected_ph_types = [
|
||||
int,
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((2, 3, 4)),
|
||||
]
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
@ -151,6 +180,7 @@ class TypeCheckerTest(TestCase):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((1, 2, 3))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
@ -161,6 +191,7 @@ class TypeCheckerTest(TestCase):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, 2, Dyn)), y: TensorType((1, 2, 3))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
@ -170,9 +201,9 @@ class TypeCheckerTest(TestCase):
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == next(expected_iter)
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((1, 2, Dyn))
|
||||
|
||||
def test_type_check_reshape_true(self):
|
||||
@ -186,13 +217,13 @@ class TypeCheckerTest(TestCase):
|
||||
self.assertTrue(tc.type_check())
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == TensorType((1, 6))
|
||||
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == TensorType((1, 2, 3))
|
||||
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((1, 2, 3))
|
||||
|
||||
def test_type_check_reshape_false(self):
|
||||
@ -244,16 +275,16 @@ class TypeCheckerTest(TestCase):
|
||||
return torch.transpose(x, 0, 1)
|
||||
|
||||
module = M()
|
||||
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
self.assertTrue(tc.type_check())
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == TensorType([2, 1, 3, 5])
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType([2, 1, 3, 5])
|
||||
if n.op == 'x':
|
||||
if n.op == "x":
|
||||
assert n.placeholder == TensorType([1, 2, 3, 5])
|
||||
|
||||
def test_type_check_transpose_False(self):
|
||||
@ -269,7 +300,6 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
def test_type_check_batch_norm_2D(self):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes):
|
||||
super().__init__()
|
||||
norm_layer = torch.nn.BatchNorm2d
|
||||
@ -289,18 +319,17 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == TensorType((2, 2, 5, 4))
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((2, 2, 5, 4))
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
assert n.type == TensorType((2, 2, 5, 4))
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == TensorType((2, 2, 5, 4))
|
||||
|
||||
def test_type_check_batch_norm_2D_false(self):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes):
|
||||
super().__init__()
|
||||
norm_layer = torch.nn.BatchNorm2d
|
||||
@ -322,7 +351,6 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
def test_type_check_batch_norm_2D_broadcast(self):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes):
|
||||
super().__init__()
|
||||
norm_layer = torch.nn.BatchNorm2d
|
||||
@ -341,13 +369,13 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, traced)
|
||||
tc.type_check()
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
assert n.type == TensorType((2, 2, Dyn, 4))
|
||||
|
||||
B = BasicBlock(1, 1)
|
||||
@ -379,13 +407,13 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, traced)
|
||||
tc.type_check()
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((Dyn, Dyn, Dyn, Dyn))
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
assert n.type == TensorType((2, 2, Dyn, 4))
|
||||
|
||||
def test_type_check_conv2D_2(self):
|
||||
@ -412,13 +440,13 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
t = TensorType((5, 2, 3, 4))
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == t
|
||||
if n.op == 'call_function':
|
||||
if n.op == "call_function":
|
||||
assert n.type == t
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert torch.Size(n.type.__args__) == b.shape
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
assert n.type == t
|
||||
|
||||
B = BasicBlock(1, 2)
|
||||
@ -430,12 +458,27 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
def test_type_check_conv2D_2_fully_static(self):
|
||||
annotation_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 3)]
|
||||
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, 15, 13, 14), (1, 2, 2, 3)]
|
||||
intermediate_types = [(1, Dyn, Dyn, 7), (2, Dyn, 4, 6), (10, 15, Dyn, 5),
|
||||
(10, 15, 7, 7), (1, Dyn, Dyn, Dyn)]
|
||||
annotation_list = [
|
||||
(1, 2, 3, 5),
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, Dyn, 13, 14),
|
||||
(Dyn, Dyn, Dyn, 3),
|
||||
]
|
||||
input_list = [
|
||||
(1, 2, 3, 5),
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, 15, 13, 14),
|
||||
(1, 2, 2, 3),
|
||||
]
|
||||
intermediate_types = [
|
||||
(1, Dyn, Dyn, 7),
|
||||
(2, Dyn, 4, 6),
|
||||
(10, 15, Dyn, 5),
|
||||
(10, 15, 7, 7),
|
||||
(1, Dyn, Dyn, Dyn),
|
||||
]
|
||||
in_planes_list = [2, 5, 15, 15, 2]
|
||||
stride_list = [1, 2, 3, 2, 2]
|
||||
out_planes_list = [2, 5, 15, 15, 2]
|
||||
@ -443,7 +486,13 @@ class TypeCheckerTest(TestCase):
|
||||
dilation_list = [1, 2, 3, 3, 3]
|
||||
padding_list = [1, 2, 3, 3, 3]
|
||||
kernel_size_list = [1, 2, 3, 3, 3]
|
||||
output_types = [(1, 2, Dyn, 7), (2, 5, 4, 6), (10, 15, Dyn, 5), (10, 15, 7, 7), (1, 2, Dyn, Dyn)]
|
||||
output_types = [
|
||||
(1, 2, Dyn, 7),
|
||||
(2, 5, 4, 6),
|
||||
(10, 15, Dyn, 5),
|
||||
(10, 15, 7, 7),
|
||||
(1, 2, Dyn, Dyn),
|
||||
]
|
||||
|
||||
for i in range(5):
|
||||
annotation = annotation_list[i]
|
||||
@ -458,24 +507,42 @@ class TypeCheckerTest(TestCase):
|
||||
intermediate_type = intermediate_types[i]
|
||||
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
|
||||
def __init__(
|
||||
self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups,
|
||||
dilation,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, groups=groups, bias=False, dilation=dilation)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels=in_planes,
|
||||
out_channels=out_planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
return out
|
||||
|
||||
B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
|
||||
B = BasicBlock(
|
||||
in_planes, out_planes, kernel_size, stride, padding, groups, dilation
|
||||
)
|
||||
ast_rewriter = RewritingTracer()
|
||||
graph = ast_rewriter.trace(B)
|
||||
traced = GraphModule(ast_rewriter.root, graph, "gm")
|
||||
|
||||
# annotate our argument
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
n.type = TensorType(annotation)
|
||||
|
||||
b = B.forward(torch.rand(input))
|
||||
@ -483,36 +550,54 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert is_consistent(n.type, TensorType(b.size()))
|
||||
|
||||
# test with intermediate annotations
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, padding, groups, dilation):
|
||||
def __init__(
|
||||
self,
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups,
|
||||
dilation,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = torch.nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, groups=groups, bias=False, dilation=dilation)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels=in_planes,
|
||||
out_channels=out_planes,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv1(x)
|
||||
return out
|
||||
|
||||
B = BasicBlock(in_planes, out_planes, kernel_size, stride, padding, groups, dilation)
|
||||
B = BasicBlock(
|
||||
in_planes, out_planes, kernel_size, stride, padding, groups, dilation
|
||||
)
|
||||
ast_rewriter = RewritingTracer()
|
||||
graph = ast_rewriter.trace(B)
|
||||
traced = GraphModule(ast_rewriter.root, graph, "gm")
|
||||
|
||||
# populate our intermediate notes
|
||||
for n in traced.graph.nodes:
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
n.type = TensorType(intermediate_type)
|
||||
|
||||
tc = GraphTypeChecker({}, traced)
|
||||
tc.type_check()
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType(output_types[i])
|
||||
assert is_consistent(n.type, TensorType(b.size()))
|
||||
|
||||
@ -520,14 +605,26 @@ class TypeCheckerTest(TestCase):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
||||
base_width=64, dilation=1):
|
||||
def __init__(
|
||||
self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
downsample=None,
|
||||
groups=1,
|
||||
base_width=64,
|
||||
dilation=1,
|
||||
):
|
||||
super().__init__()
|
||||
norm_layer = torch.nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
raise ValueError(
|
||||
"BasicBlock only supports groups=1 and base_width=64"
|
||||
)
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
raise NotImplementedError(
|
||||
"Dilation > 1 not supported in BasicBlock"
|
||||
)
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
@ -565,12 +662,14 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.target == 'output':
|
||||
if n.target == "output":
|
||||
assert isinstance(n.type, TensorType)
|
||||
assert torch.Size(n.type.__args__) == B.forward(torch.rand(2, 2, 4, 5)).size()
|
||||
assert (
|
||||
torch.Size(n.type.__args__)
|
||||
== B.forward(torch.rand(2, 2, 4, 5)).size()
|
||||
)
|
||||
|
||||
def test_type_check_conv2D_maxpool2d_flatten(self):
|
||||
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -581,7 +680,7 @@ class TypeCheckerTest(TestCase):
|
||||
self.fc1 = torch.nn.Linear(5, 120)
|
||||
self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))
|
||||
|
||||
def forward(self, x : TensorType((4, 3, 32, 32))):
|
||||
def forward(self, x: TensorType((4, 3, 32, 32))):
|
||||
out = self.conv1(x)
|
||||
out = self.pool(out)
|
||||
out = self.conv2(out)
|
||||
@ -598,10 +697,17 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, traced)
|
||||
tc.type_check()
|
||||
|
||||
expected_ph_types = [TensorType((4, 3, 32, 32)), TensorType((4, 6, 28, 28)),
|
||||
TensorType((4, 6, 14, 14)), TensorType((4, 16, 10, 10)),
|
||||
TensorType((4, 16, 5, 5)), TensorType((4, 16, 5, 120)),
|
||||
TensorType((4, 16, 6, 7)), TensorType((4, 672)), TensorType((4, 672))]
|
||||
expected_ph_types = [
|
||||
TensorType((4, 3, 32, 32)),
|
||||
TensorType((4, 6, 28, 28)),
|
||||
TensorType((4, 6, 14, 14)),
|
||||
TensorType((4, 16, 10, 10)),
|
||||
TensorType((4, 16, 5, 5)),
|
||||
TensorType((4, 16, 5, 120)),
|
||||
TensorType((4, 16, 6, 7)),
|
||||
TensorType((4, 672)),
|
||||
TensorType((4, 672)),
|
||||
]
|
||||
|
||||
expected_iter = iter(expected_ph_types)
|
||||
traced.graph.eliminate_dead_code()
|
||||
@ -619,10 +725,9 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
tc.type_check()
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((1, 6, 5, Dyn))
|
||||
|
||||
|
||||
def test_type_check_flatten_2(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, Dyn, 3, 5, Dyn))):
|
||||
@ -633,7 +738,7 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
tc.type_check()
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((1, Dyn, 5, Dyn))
|
||||
|
||||
def test_type_check_flatten3(self):
|
||||
@ -646,7 +751,7 @@ class TypeCheckerTest(TestCase):
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
tc.type_check()
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType((2, 60))
|
||||
r = Refine(symbolic_traced)
|
||||
r.refine()
|
||||
@ -654,13 +759,12 @@ class TypeCheckerTest(TestCase):
|
||||
assert c == [Equality(2, 2)]
|
||||
|
||||
def test_type_typechecl_maxpool2d_3dinput(self):
|
||||
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool = torch.nn.MaxPool2d(5, 8)
|
||||
|
||||
def forward(self, x : TensorType((64, 8, 8))):
|
||||
def forward(self, x: TensorType((64, 8, 8))):
|
||||
out = self.pool(x)
|
||||
return out
|
||||
|
||||
@ -672,21 +776,42 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.target == 'output':
|
||||
if n.target == "output":
|
||||
assert n.type == TensorType((64, 1, 1))
|
||||
|
||||
def test_type_maxpool2d_fully_static(self):
|
||||
annotation_list = [(Dyn, Dyn, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, Dyn, 13, 14), (Dyn, Dyn, Dyn, 10)]
|
||||
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, 15, 13, 14), (2, 2, 10, 10)]
|
||||
intermediate_types = [(1, 2, Dyn, Dyn), (2, Dyn, 2, 4), (10, 15, Dyn, 2),
|
||||
(10, 15, 2, 3), (2, Dyn, Dyn, Dyn)]
|
||||
annotation_list = [
|
||||
(Dyn, Dyn, 3, 5),
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, Dyn, 13, 14),
|
||||
(Dyn, Dyn, Dyn, 10),
|
||||
]
|
||||
input_list = [
|
||||
(1, 2, 3, 5),
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, 15, 13, 14),
|
||||
(2, 2, 10, 10),
|
||||
]
|
||||
intermediate_types = [
|
||||
(1, 2, Dyn, Dyn),
|
||||
(2, Dyn, 2, 4),
|
||||
(10, 15, Dyn, 2),
|
||||
(10, 15, 2, 3),
|
||||
(2, Dyn, Dyn, Dyn),
|
||||
]
|
||||
stride_list = [1, 2, 3, 2, 1]
|
||||
dilation_list = [1, 2, 3, 3, 2]
|
||||
padding_list = [1, 2, 3, 3, 1]
|
||||
kernel_size_list = [2, 4, 6, 6, 3]
|
||||
output_types = [(1, 2, 4, 6), (2, 5, 2, 4), (10, 15, 2, 2), (10, 15, 2, 3), (2, Dyn, Dyn, 8)]
|
||||
output_types = [
|
||||
(1, 2, 4, 6),
|
||||
(2, 5, 2, 4),
|
||||
(10, 15, 2, 2),
|
||||
(10, 15, 2, 3),
|
||||
(2, Dyn, Dyn, 8),
|
||||
]
|
||||
|
||||
for i in range(5):
|
||||
annotation = annotation_list[i]
|
||||
@ -700,9 +825,14 @@ class TypeCheckerTest(TestCase):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self, kernel_size, stride, padding, dilation):
|
||||
super().__init__()
|
||||
self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
return_indices=False, ceil_mode=False)
|
||||
self.pool = torch.nn.MaxPool2d(
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
return_indices=False,
|
||||
ceil_mode=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.pool(x)
|
||||
@ -715,7 +845,7 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
# annotate our argument
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
n.type = TensorType(annotation)
|
||||
|
||||
b = B.forward(torch.rand(input))
|
||||
@ -723,16 +853,21 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert is_consistent(n.type, TensorType(b.size()))
|
||||
|
||||
# test with intermediate annotations
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self, kernel_size, stride, padding, dilation):
|
||||
super().__init__()
|
||||
self.pool = torch.nn.MaxPool2d(kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
return_indices=False, ceil_mode=False)
|
||||
self.pool = torch.nn.MaxPool2d(
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
return_indices=False,
|
||||
ceil_mode=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.pool(x)
|
||||
@ -745,30 +880,45 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
# annotate our argument
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
n.type = TensorType(annotation)
|
||||
|
||||
# populate our intermediate notes
|
||||
for n in traced.graph.nodes:
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
n.type = TensorType(intermediate_type)
|
||||
|
||||
tc = GraphTypeChecker({}, traced)
|
||||
tc.type_check()
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert n.type == TensorType(output_types[i])
|
||||
assert is_consistent(n.type, TensorType(b.size()))
|
||||
|
||||
def test_flatten_fully_static(self):
|
||||
annotation_list = [Dyn, TensorType((2, 5, 6, 9)), TensorType((10, 15, 13, 14)),
|
||||
TensorType((10, Dyn, 13, 14)), TensorType((Dyn, Dyn, Dyn, 10))]
|
||||
input_list = [(1, 2, 3, 5), (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, 15, 13, 14), (2, 2, 10, 10)]
|
||||
annotation_list = [
|
||||
Dyn,
|
||||
TensorType((2, 5, 6, 9)),
|
||||
TensorType((10, 15, 13, 14)),
|
||||
TensorType((10, Dyn, 13, 14)),
|
||||
TensorType((Dyn, Dyn, Dyn, 10)),
|
||||
]
|
||||
input_list = [
|
||||
(1, 2, 3, 5),
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, 15, 13, 14),
|
||||
(2, 2, 10, 10),
|
||||
]
|
||||
|
||||
intermediate_list = [Dyn, (2, 5, 6, 9), (10, 15, 13, 14),
|
||||
(10, 15, 13, 14), (2, 2, 10, 10)]
|
||||
intermediate_list = [
|
||||
Dyn,
|
||||
(2, 5, 6, 9),
|
||||
(10, 15, 13, 14),
|
||||
(10, 15, 13, 14),
|
||||
(2, 2, 10, 10),
|
||||
]
|
||||
|
||||
start_dim = [1, 2, 1, 2, 0]
|
||||
end_dim = [1, 3, 3, 3, -2]
|
||||
@ -795,7 +945,7 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
# annotate our argument
|
||||
for n in graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
n.type = annotation
|
||||
|
||||
b = B.forward(torch.rand(input))
|
||||
@ -803,7 +953,7 @@ class TypeCheckerTest(TestCase):
|
||||
tc.type_check()
|
||||
|
||||
for n in graph.nodes:
|
||||
if n.op == 'output':
|
||||
if n.op == "output":
|
||||
assert is_consistent(n.type, TensorType(b.size()))
|
||||
|
||||
@skipIfNoTorchVision
|
||||
@ -825,36 +975,34 @@ class TypeCheckerTest(TestCase):
|
||||
gm_run.graph.eliminate_dead_code()
|
||||
# here we are checking for consistency with fully dynamic nodes
|
||||
for n1, n2 in zip(gm_static.graph.nodes, gm_run.graph.nodes):
|
||||
assert is_consistent(n1.type, TensorType(n2.meta['tensor_meta'].shape))
|
||||
assert is_consistent(n1.type, TensorType(n2.meta["tensor_meta"].shape))
|
||||
|
||||
# here we give the same input as to runtime
|
||||
gm_static_with_types = symbolic_trace(resnet50())
|
||||
|
||||
# we initialize our placeholder
|
||||
for n in gm_static_with_types.graph.nodes:
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
n.type = TensorType((1, 3, 224, 224))
|
||||
|
||||
g = GraphTypeChecker({}, gm_static_with_types)
|
||||
g.type_check()
|
||||
for n1, n2 in zip(gm_static_with_types.graph.nodes, gm_run.graph.nodes):
|
||||
assert n1.type == TensorType(n2.meta['tensor_meta'].shape)
|
||||
assert n1.type == TensorType(n2.meta["tensor_meta"].shape)
|
||||
|
||||
# apply shape inference to graph and check
|
||||
# that the batch size is equal across all layers
|
||||
infer_symbolic_types(gm_static)
|
||||
|
||||
|
||||
batch_sizes = set()
|
||||
gm_static.graph.eliminate_dead_code()
|
||||
for n in gm_static.graph.nodes:
|
||||
assert isinstance(n.type, TensorType)
|
||||
batch_sizes.add(n.type.__args__[0])
|
||||
assert (len(batch_sizes) == 1)
|
||||
assert len(batch_sizes) == 1
|
||||
|
||||
def test_type_check_batch_norm_symbolic(self):
|
||||
class BasicBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes):
|
||||
super().__init__()
|
||||
norm_layer = torch.nn.BatchNorm2d
|
||||
@ -875,10 +1023,14 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
infer_symbolic_types(traced)
|
||||
|
||||
my_types = iter([TensorType[(2, 2, sympy.symbols('~7'), 4)],
|
||||
TensorType[(2, 2, sympy.symbols('~7'), 4)],
|
||||
TensorType[(2, 2, sympy.symbols('~7'), 4)],
|
||||
TensorType[(2, 2, sympy.symbols('~7'), 4)]])
|
||||
my_types = iter(
|
||||
[
|
||||
TensorType[(2, 2, sympy.symbols("~7"), 4)],
|
||||
TensorType[(2, 2, sympy.symbols("~7"), 4)],
|
||||
TensorType[(2, 2, sympy.symbols("~7"), 4)],
|
||||
TensorType[(2, 2, sympy.symbols("~7"), 4)],
|
||||
]
|
||||
)
|
||||
|
||||
for n in graph.nodes:
|
||||
assert n.type == next(my_types)
|
||||
@ -887,6 +1039,7 @@ class TypeCheckerTest(TestCase):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: TensorType((2, 3, 4))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
@ -901,13 +1054,14 @@ class TypeCheckerTest(TestCase):
|
||||
|
||||
infer_symbolic_types(symbolic_traced)
|
||||
|
||||
expected_ph_types = [TensorType((1, 2, 3, sympy.symbols('~0'))),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((1, 2, 3, sympy.symbols('~1'))),
|
||||
TensorType((1, 2, 3, sympy.symbols('~1')))]
|
||||
expected_ph_types = [
|
||||
TensorType((1, 2, 3, sympy.symbols("~0"))),
|
||||
TensorType((2, 3, 4)),
|
||||
TensorType((1, 2, 3, sympy.symbols("~1"))),
|
||||
TensorType((1, 2, 3, sympy.symbols("~1"))),
|
||||
]
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
assert n.type == next(expected_iter)
|
||||
|
||||
@ -915,6 +1069,7 @@ class TypeCheckerTest(TestCase):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x: TensorType((1, 2)), y: TensorType((Dyn, 2))):
|
||||
return torch.add(x, y)
|
||||
|
||||
module = M()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
tc = GraphTypeChecker({}, symbolic_traced)
|
||||
@ -923,10 +1078,12 @@ class TypeCheckerTest(TestCase):
|
||||
r = Refine(symbolic_traced)
|
||||
r.refine()
|
||||
|
||||
expected_ph_types = [TensorType((1, 2)),
|
||||
TensorType((sympy.symbols('~1'), 2)),
|
||||
TensorType((sympy.symbols('~1'), 2)),
|
||||
TensorType((sympy.symbols('~1'), 2))]
|
||||
expected_ph_types = [
|
||||
TensorType((1, 2)),
|
||||
TensorType((sympy.symbols("~1"), 2)),
|
||||
TensorType((sympy.symbols("~1"), 2)),
|
||||
TensorType((sympy.symbols("~1"), 2)),
|
||||
]
|
||||
expected_iter = iter(expected_ph_types)
|
||||
|
||||
for n in symbolic_traced.graph.nodes:
|
||||
@ -955,12 +1112,11 @@ class TypeCheckerTest(TestCase):
|
||||
infer_symbolic_types(traced)
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.op == 'call_module':
|
||||
if n.op == "call_module":
|
||||
assert isinstance(n.type.__args__[2], sympy.floor)
|
||||
assert isinstance(n.type.__args__[3], sympy.floor)
|
||||
|
||||
def test_type_check_symbolic_inferenceconv2D_maxpool2d_flatten(self):
|
||||
|
||||
class BasicBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -971,7 +1127,7 @@ class TypeCheckerTest(TestCase):
|
||||
self.fc1 = torch.nn.Linear(5, 120)
|
||||
self.pool2 = torch.nn.AdaptiveAvgPool2d((6, 7))
|
||||
|
||||
def forward(self, x : TensorType((4, 3, Dyn, Dyn))):
|
||||
def forward(self, x: TensorType((4, 3, Dyn, Dyn))):
|
||||
out = self.conv1(x)
|
||||
out = self.pool(out)
|
||||
out = self.conv2(out)
|
||||
@ -989,13 +1145,26 @@ class TypeCheckerTest(TestCase):
|
||||
infer_symbolic_types(traced)
|
||||
|
||||
for n in traced.graph.nodes:
|
||||
if n.target == 'conv1':
|
||||
assert n.type == TensorType((4, 6, sympy.floor(sympy.symbols('~0') - 4),
|
||||
sympy.floor(sympy.symbols('~1') - 4)))
|
||||
if n.target == "conv1":
|
||||
assert n.type == TensorType(
|
||||
(
|
||||
4,
|
||||
6,
|
||||
sympy.floor(sympy.symbols("~0") - 4),
|
||||
sympy.floor(sympy.symbols("~1") - 4),
|
||||
)
|
||||
)
|
||||
|
||||
elif n.target == 'conv2':
|
||||
assert n.type == TensorType((4, 16, sympy.floor(sympy.symbols('~4') - 4),
|
||||
sympy.floor(sympy.symbols('~5') - 4)))
|
||||
elif n.target == "conv2":
|
||||
assert n.type == TensorType(
|
||||
(
|
||||
4,
|
||||
16,
|
||||
sympy.floor(sympy.symbols("~4") - 4),
|
||||
sympy.floor(sympy.symbols("~5") - 4),
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -69,10 +69,12 @@ class TestMatcher(JitTestCase):
|
||||
def test_subgraph_matcher_with_list(self):
|
||||
def original(x, y):
|
||||
return torch.ops.aten.view(x, [5, y.shape[0]])
|
||||
|
||||
original_graph = torch.fx.symbolic_trace(original).graph
|
||||
|
||||
def pattern(x, y, z):
|
||||
return torch.ops.aten.view(x, [z, y.shape[0]])
|
||||
|
||||
pattern_graph = torch.fx.symbolic_trace(pattern).graph
|
||||
|
||||
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
||||
@ -81,11 +83,17 @@ class TestMatcher(JitTestCase):
|
||||
|
||||
def test_subgraph_matcher_with_list_bad(self):
|
||||
def original(x, y):
|
||||
return torch.ops.aten._reshape_alias_copy.default(x, [1, y.shape[0]], [y.shape[1], y.shape[1]])
|
||||
return torch.ops.aten._reshape_alias_copy.default(
|
||||
x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
|
||||
)
|
||||
|
||||
original_graph = torch.fx.symbolic_trace(original).graph
|
||||
|
||||
def pattern(x, y, b):
|
||||
return torch.ops.aten._reshape_alias_copy.default(x, [b, y.shape[0], y.shape[1]], [y.shape[1]])
|
||||
return torch.ops.aten._reshape_alias_copy.default(
|
||||
x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
|
||||
)
|
||||
|
||||
pattern_graph = torch.fx.symbolic_trace(pattern).graph
|
||||
|
||||
subgraph_matcher = SubgraphMatcher(pattern_graph)
|
||||
@ -101,6 +109,7 @@ class TestMatcher(JitTestCase):
|
||||
|
||||
def pattern(x):
|
||||
return x + 2
|
||||
|
||||
pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
|
||||
pattern_graph.eliminate_dead_code()
|
||||
|
||||
@ -116,7 +125,10 @@ class TestMatcher(JitTestCase):
|
||||
inputs = (torch.randn(20, 16, 50, 32),)
|
||||
|
||||
def maxpool(x, kernel_size, stride, padding, dilation):
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(x, kernel_size, stride, padding, dilation)
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(
|
||||
x, kernel_size, stride, padding, dilation
|
||||
)
|
||||
|
||||
maxpool_graph = torch.fx.symbolic_trace(maxpool).graph
|
||||
|
||||
maxpool_matcher = SubgraphMatcher(maxpool_graph)
|
||||
@ -144,7 +156,9 @@ class TestMatcher(JitTestCase):
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
|
||||
def test_split_to_graph_and_name_node_map(self):
|
||||
"""Testing the internal helper function for splitting the pattern graph"""
|
||||
from torch.fx.passes.utils.matcher_with_name_node_map_utils import _split_to_graph_and_name_node_map
|
||||
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
|
||||
_split_to_graph_and_name_node_map,
|
||||
)
|
||||
|
||||
def pattern(x, weight):
|
||||
conv = F.conv2d(x, weight)
|
||||
@ -153,6 +167,7 @@ class TestMatcher(JitTestCase):
|
||||
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
|
||||
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(1, 3, 3, 3) * 10,
|
||||
torch.randn(3, 3, 3, 3),
|
||||
@ -166,8 +181,7 @@ class TestMatcher(JitTestCase):
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
|
||||
def test_matcher_with_name_node_map_function(self):
|
||||
"""Testing SubgraphMatcherWithNameNodeMap with function pattern
|
||||
"""
|
||||
"""Testing SubgraphMatcherWithNameNodeMap with function pattern"""
|
||||
|
||||
def target_graph(x, weight):
|
||||
x = x * 2
|
||||
@ -184,13 +198,16 @@ class TestMatcher(JitTestCase):
|
||||
return relu, relu_mul_by_two, {"conv": conv, "relu": relu}
|
||||
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
|
||||
example_inputs = (
|
||||
torch.randn(1, 3, 3, 3) * 10,
|
||||
torch.randn(3, 3, 3, 3),
|
||||
)
|
||||
pattern_gm = capture_pre_autograd_graph(WrapperModule(pattern), example_inputs)
|
||||
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
|
||||
target_gm = capture_pre_autograd_graph(WrapperModule(target_graph), example_inputs)
|
||||
target_gm = capture_pre_autograd_graph(
|
||||
WrapperModule(target_graph), example_inputs
|
||||
)
|
||||
internal_matches = matcher.match(target_gm.graph)
|
||||
for internal_match in internal_matches:
|
||||
name_node_map = internal_match.name_node_map
|
||||
@ -200,12 +217,14 @@ class TestMatcher(JitTestCase):
|
||||
# check if we correctly annotated the target graph module
|
||||
for n in target_gm.graph.nodes:
|
||||
if n == name_node_map["conv"]:
|
||||
assert "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation"
|
||||
assert (
|
||||
"custom_annotation" in n.meta
|
||||
and n.meta["custom_annotation"] == "annotation"
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
|
||||
def test_matcher_with_name_node_map_module(self):
|
||||
"""Testing SubgraphMatcherWithNameNodeMap with module pattern
|
||||
"""
|
||||
"""Testing SubgraphMatcherWithNameNodeMap with module pattern"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -215,7 +234,6 @@ class TestMatcher(JitTestCase):
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class Pattern(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -228,9 +246,8 @@ class TestMatcher(JitTestCase):
|
||||
return linear, {"linear": linear, "x": x}
|
||||
|
||||
from torch._export import capture_pre_autograd_graph
|
||||
example_inputs = (
|
||||
torch.randn(3, 5),
|
||||
)
|
||||
|
||||
example_inputs = (torch.randn(3, 5),)
|
||||
pattern_gm = capture_pre_autograd_graph(Pattern(), example_inputs)
|
||||
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
|
||||
target_gm = capture_pre_autograd_graph(M(), example_inputs)
|
||||
@ -243,7 +260,11 @@ class TestMatcher(JitTestCase):
|
||||
# check if we correctly annotated the target graph module
|
||||
for n in target_gm.graph.nodes:
|
||||
if n == name_node_map["linear"]:
|
||||
assert "custom_annotation" in n.meta and n.meta["custom_annotation"] == "annotation"
|
||||
assert (
|
||||
"custom_annotation" in n.meta
|
||||
and n.meta["custom_annotation"] == "annotation"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -9,9 +9,13 @@ import torch
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch._dynamo.eval_frame import is_dynamo_supported
|
||||
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions, check_subgraphs_connected
|
||||
from torch.fx.passes.utils.source_matcher_utils import (
|
||||
check_subgraphs_connected,
|
||||
get_source_partitions,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
class TestSourceMatcher(JitTestCase):
|
||||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
|
||||
def test_module_partitioner_linear_relu_linear(self):
|
||||
@ -33,15 +37,32 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Linear, torch.nn.ReLU])
|
||||
module_partitions = get_source_partitions(
|
||||
gm.graph, [torch.nn.Linear, torch.nn.ReLU]
|
||||
)
|
||||
|
||||
self.assertEqual(len(module_partitions), 2)
|
||||
self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
|
||||
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
|
||||
|
||||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][0], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Linear][1], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Linear][2], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertFalse(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Linear][0],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Linear][1],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Linear][2],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
|
||||
def test_module_partitioner_conv_relu_maxpool(self):
|
||||
@ -69,21 +90,50 @@ class TestSourceMatcher(JitTestCase):
|
||||
return self.maxpool(self.relu(z))
|
||||
|
||||
inputs = (torch.randn(1, 3, 256, 256),)
|
||||
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(*inputs)
|
||||
gm, _ = torch._dynamo.export(M(torch.ones(1, 16, 256, 256)), aten_graph=True)(
|
||||
*inputs
|
||||
)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d])
|
||||
module_partitions = get_source_partitions(
|
||||
gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
|
||||
)
|
||||
|
||||
self.assertEqual(len(module_partitions), 3)
|
||||
self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
|
||||
self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
|
||||
self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)
|
||||
|
||||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][0], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][1], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.Conv2d][2], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertFalse(check_subgraphs_connected(module_partitions[torch.nn.MaxPool2d][0], module_partitions[torch.nn.ReLU][0]))
|
||||
self.assertTrue(check_subgraphs_connected(module_partitions[torch.nn.ReLU][0], module_partitions[torch.nn.MaxPool2d][0]))
|
||||
self.assertFalse(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Conv2d][0],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Conv2d][1],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.Conv2d][2],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.MaxPool2d][0],
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
check_subgraphs_connected(
|
||||
module_partitions[torch.nn.ReLU][0],
|
||||
module_partitions[torch.nn.MaxPool2d][0],
|
||||
)
|
||||
)
|
||||
|
||||
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
|
||||
def test_module_partitioner_functional_conv_relu_conv(self):
|
||||
@ -96,7 +146,15 @@ class TestSourceMatcher(JitTestCase):
|
||||
self.groups = 1
|
||||
|
||||
def forward(self, x, weight, bias):
|
||||
return torch.nn.functional.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return torch.nn.functional.conv2d(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
self.stride,
|
||||
self.padding,
|
||||
self.dilation,
|
||||
self.groups,
|
||||
)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
@ -114,7 +172,9 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.conv2d])
|
||||
module_partitions = get_source_partitions(
|
||||
gm.graph, [torch.nn.functional.conv2d]
|
||||
)
|
||||
|
||||
self.assertEqual(len(module_partitions), 1)
|
||||
self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)
|
||||
@ -138,7 +198,9 @@ class TestSourceMatcher(JitTestCase):
|
||||
gm, _ = torch._dynamo.export(M(), aten_graph=True)(*inputs)
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
module_partitions = get_source_partitions(gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu])
|
||||
module_partitions = get_source_partitions(
|
||||
gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
|
||||
)
|
||||
|
||||
self.assertEqual(len(module_partitions), 2)
|
||||
self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
|
||||
|
@ -4,8 +4,9 @@ import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from torch.fx import symbolic_trace, subgraph_rewriter
|
||||
from torch.fx import subgraph_rewriter, symbolic_trace
|
||||
from torch.fx.annotate import annotate
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
from torch.fx.experimental.rewriter import RewritingTracer
|
||||
|
||||
@ -13,10 +14,13 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_fx.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError(
|
||||
"This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_fx.py TESTNAME\n\n"
|
||||
"instead."
|
||||
)
|
||||
|
||||
|
||||
@torch.fx.wrap
|
||||
def wrapped_gemm_bias_mul(a, b, bias):
|
||||
@ -24,14 +28,15 @@ def wrapped_gemm_bias_mul(a, b, bias):
|
||||
mul_res = lin_res * a
|
||||
return lin_res, mul_res
|
||||
|
||||
|
||||
@torch.fx.wrap
|
||||
def wrapped_gemm_bias_mul_with_c(a, b, bias, c):
|
||||
lin_res = torch.nn.functional.linear(a, b, bias=bias)
|
||||
mul_res = lin_res * c
|
||||
return lin_res, mul_res
|
||||
|
||||
class TestSubgraphRewriter(JitTestCase):
|
||||
|
||||
class TestSubgraphRewriter(JitTestCase):
|
||||
def test_subgraph_rewriter_preserves_logic(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -110,7 +115,9 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
|
||||
x = torch.randn(1, 5)
|
||||
|
||||
matches = subgraph_rewriter.replace_pattern_with_filters(traced, pattern, replacement, [])
|
||||
matches = subgraph_rewriter.replace_pattern_with_filters(
|
||||
traced, pattern, replacement, []
|
||||
)
|
||||
|
||||
traced.graph.lint()
|
||||
|
||||
@ -297,7 +304,9 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
test_outs = traced.forward(x)
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(self):
|
||||
def test_subgraph_rewriter_pattern_output_pattern_node_can_have_users_that_are_not_matched(
|
||||
self,
|
||||
):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y = torch.relu(x)
|
||||
@ -326,7 +335,9 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
test_outs = traced.forward(x)
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(self):
|
||||
def test_subgraph_rewriter_internal_pattern_nodes_cannot_have_users_that_are_not_matched(
|
||||
self,
|
||||
):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, w1, w2, b1, b2):
|
||||
m0 = torch.cat([w1, w2])
|
||||
@ -385,6 +396,7 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
|
||||
Credit to Jerry Zhang (GitHub: jerryzh168) for this test case
|
||||
"""
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -483,7 +495,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(type(submod), torch.nn.ReLU)
|
||||
|
||||
def test_subgraph_rewriter_annotations_int(self):
|
||||
|
||||
class M1(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
y: int = x
|
||||
@ -500,12 +511,11 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
module = M2()
|
||||
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
|
||||
for n, m in zip(symbolic_traced.graph.nodes, graph.nodes):
|
||||
if n.op == 'placeholder':
|
||||
if n.op == "placeholder":
|
||||
assert n.type == int
|
||||
assert m.type == int
|
||||
|
||||
def test_subgraph_rewriter_replace_consecutive_submodules(self):
|
||||
|
||||
def f(x):
|
||||
x = torch.sigmoid(x)
|
||||
x = torch.sigmoid(x)
|
||||
@ -536,7 +546,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_with_overlapping_matches(self):
|
||||
|
||||
def f(x):
|
||||
x = torch.sigmoid(x)
|
||||
x = torch.sigmoid(x)
|
||||
@ -569,7 +578,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_replace_with_multiple_outputs(self):
|
||||
|
||||
def f(x):
|
||||
y = torch.sigmoid(x)
|
||||
z = torch.relu(x)
|
||||
@ -602,7 +610,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_replace_with_duplicated_outputs(self):
|
||||
|
||||
def f(x1, x2):
|
||||
x = x1 - x2
|
||||
y = torch.sigmoid(x)
|
||||
@ -670,7 +677,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_call_method(self):
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
x = x.dequantize()
|
||||
@ -701,7 +707,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertEqual(ref_outs, test_outs)
|
||||
|
||||
def test_subgraph_rewriter_nodes_with_kwargs(self):
|
||||
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
@ -737,7 +742,6 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
self.assertTrue(found_repalcement_node)
|
||||
|
||||
def test_subgraph_rewriter_local_revert(self):
|
||||
|
||||
# Following model will have 3 anchors as the matching candidate with the given pattern
|
||||
# Anchor 1 and 3 is a real match, but anchor 2 is not.
|
||||
# The subgraph rewriter should be able to revert the changes made while matching anchor 2.
|
||||
@ -763,9 +767,7 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
# potential match at anchor 1
|
||||
mul_res_1 = in1 * lin_res_2
|
||||
sum_res_1 = mul_res_1 + in1
|
||||
lin_res_3 = torch.nn.functional.linear(
|
||||
sum_res_1, self.w2, bias=self.b2
|
||||
)
|
||||
lin_res_3 = torch.nn.functional.linear(sum_res_1, self.w2, bias=self.b2)
|
||||
sigmoid_res_1 = torch.sigmoid(lin_res_3)
|
||||
# potential match at anchor 2
|
||||
mul_res_2 = lin_res_3 * sigmoid_res_1
|
||||
@ -791,9 +793,8 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
|
||||
traced = symbolic_trace(M())
|
||||
matches = subgraph_rewriter.replace_pattern(
|
||||
traced,
|
||||
gemm_bias_mul_pattern_with_c,
|
||||
gemm_bias_mul_replacement_with_c)
|
||||
traced, gemm_bias_mul_pattern_with_c, gemm_bias_mul_replacement_with_c
|
||||
)
|
||||
|
||||
self.assertEqual(len(matches), 2)
|
||||
|
||||
@ -834,7 +835,7 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
return x
|
||||
|
||||
def second_input_is_scalar(match, original_graph, pattern_graph):
|
||||
""" check the node that's matched to the second input of the pattern graph
|
||||
"""check the node that's matched to the second input of the pattern graph
|
||||
is a scalar number
|
||||
"""
|
||||
input_idx = 0
|
||||
@ -848,19 +849,21 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
return True
|
||||
|
||||
def check_replacement_nodes(self, traced, matches):
|
||||
replacement_nodes_in_graph = [node for node in traced.graph.nodes if node.target == torch.mul]
|
||||
replacement_nodes_in_graph = [
|
||||
node for node in traced.graph.nodes if node.target == torch.mul
|
||||
]
|
||||
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
|
||||
self.assertEqual(len(replacement_nodes_in_graph), len(replacement_nodes_in_res))
|
||||
self.assertEqual(
|
||||
len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
|
||||
)
|
||||
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
|
||||
return len(replacement_nodes_in_graph)
|
||||
|
||||
# match without filter, should find 2 match
|
||||
traced = symbolic_trace(M())
|
||||
matches = subgraph_rewriter.replace_pattern_with_filters(
|
||||
traced,
|
||||
BinaryOpScalarReLUPattern,
|
||||
BinaryOpScalarReLUReplacement,
|
||||
None)
|
||||
traced, BinaryOpScalarReLUPattern, BinaryOpScalarReLUReplacement, None
|
||||
)
|
||||
self.assertEqual(len(matches), 2)
|
||||
self.assertEqual(check_replacement_nodes(self, traced, matches), 2)
|
||||
|
||||
@ -870,7 +873,8 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
traced,
|
||||
BinaryOpScalarReLUPattern,
|
||||
BinaryOpScalarReLUReplacement,
|
||||
[second_input_is_scalar])
|
||||
[second_input_is_scalar],
|
||||
)
|
||||
self.assertEqual(len(matches), 1)
|
||||
self.assertEqual(check_replacement_nodes(self, traced, matches), 1)
|
||||
|
||||
@ -890,10 +894,13 @@ class TestSubgraphRewriter(JitTestCase):
|
||||
|
||||
self.assertEqual(len(matches), 1)
|
||||
|
||||
self.assertExpectedInline(traced.code.strip(), """\
|
||||
self.assertExpectedInline(
|
||||
traced.code.strip(),
|
||||
"""\
|
||||
def forward(self, x):
|
||||
_reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None
|
||||
return _reshape_alias_copy_default_1""") # noqa: B950
|
||||
return _reshape_alias_copy_default_1""",
|
||||
) # noqa: B950
|
||||
|
||||
def test_replacement_with_attrs(self):
|
||||
class M(torch.nn.Module):
|
||||
@ -928,11 +935,15 @@ def forward(self, x):
|
||||
def test_matching_variable_arguments(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(x, [2, 2], stride=[2, 2])
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(
|
||||
x, [2, 2], stride=[2, 2]
|
||||
)
|
||||
|
||||
def pattern(x, kernel_size, stride):
|
||||
# default padding is [0, 0]
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(x, kernel_size, stride, padding=[0, 0])
|
||||
return torch.ops.aten.max_pool2d_with_indices.default(
|
||||
x, kernel_size, stride, padding=[0, 0]
|
||||
)
|
||||
|
||||
traced = symbolic_trace(M())
|
||||
matches = subgraph_rewriter.replace_pattern(traced, pattern, pattern)
|
||||
@ -951,12 +962,20 @@ def forward(self, x):
|
||||
return torch.sub(torch.mul(x, y), y)
|
||||
|
||||
traced = symbolic_trace(M())
|
||||
matches = subgraph_rewriter.replace_pattern_with_filters(traced, pattern, replacement)
|
||||
matches = subgraph_rewriter.replace_pattern_with_filters(
|
||||
traced, pattern, replacement
|
||||
)
|
||||
|
||||
def check_replacement_nodes(self, traced, matches):
|
||||
replacement_nodes_in_graph = [node for node in traced.graph.nodes if node.target in {torch.sub, torch.mul}]
|
||||
replacement_nodes_in_graph = [
|
||||
node
|
||||
for node in traced.graph.nodes
|
||||
if node.target in {torch.sub, torch.mul}
|
||||
]
|
||||
replacement_nodes_in_res = [r for m in matches for r in m.replacements]
|
||||
self.assertEqual(len(replacement_nodes_in_graph), len(replacement_nodes_in_res))
|
||||
self.assertEqual(
|
||||
len(replacement_nodes_in_graph), len(replacement_nodes_in_res)
|
||||
)
|
||||
self.assertEqual(replacement_nodes_in_graph, replacement_nodes_in_res)
|
||||
return len(replacement_nodes_in_graph)
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user