# Owner(s): ["oncall: fx"]

import math
import numbers
import operator
import sys
import unittest
from typing import Callable, Dict, Union, List, Optional
from types import BuiltinFunctionType

import torch
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeOperators, NormalizeArgs
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.experimental.partitioner_utils import (
    NodeLatency,
    get_partition_to_latency_mapping,
    get_latency_of_partitioned_graph,
    Device,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.experimental.meta_tracer import MetaTracer
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    normalize_function,
    normalize_module,
    type_matches,
    create_type_hint,
)
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.fx.passes.split_module import split_module
from torch.testing._internal.common_device_type import (
    ops,
    onlyCPU,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.jit_utils import JitTestCase

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )
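

# NOTE: unlike a plain `symbolic_trace`, `RewritingTracer` first rewrites the
# target's AST so that Python `assert` statements become traceable
# `torch._assert` calls; the `test_call_to_assert_*` cases below rely on this.
# A minimal sketch of what the rewrite buys us (the module here is
# illustrative, not part of the suite):
#
#     class Guarded(torch.nn.Module):
#         def forward(self, a, b):
#             assert a == b  # recorded as a call_function node on torch._assert
#             return a + b
#
#     guarded_gm = symbolic_trace_with_rewrite(Guarded())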


class TestFXExperimental(JitTestCase):
    def test_serialize_graph(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.e = torch.rand(4)
                self.conv = torch.nn.Conv2d(3, 3, 2, bias=False)

            def forward(self, a, b, c):
                add_1 = a + b
                conv1 = self.conv(c)
                linear = self.linear(add_1 + conv1)
                add_2 = linear + self.e
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        c = torch.rand(3, 3, 2, 2)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, c])

        partitioner = Partitioner()
        devices = [Device("dev_0", 5000, 0), Device("dev_1", 125, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        # Fix for now to add type/shape to output
        for node in traced.graph.nodes:
            if node.op == "output":
                node.meta["tensor_meta"] = _extract_tensor_metadata(a)
        for mod in module_with_submodules.modules():
            if isinstance(mod, GraphModule):
                for node in mod.graph.nodes:
                    node.meta["tensor_meta"] = _extract_tensor_metadata(a)
        for node in module_with_submodules.graph.nodes:
            node.meta["tensor_meta"] = _extract_tensor_metadata(a)

        weights1 = {}
        weights2 = {}
        serialized_graph1 = graph_manipulation.serialize_module(traced, weights1)
        serialized_graph2 = graph_manipulation.serialize_module(
            module_with_submodules, weights2
        )
        assert len(weights1) == 4
        assert len(weights2) == 4
        assert len(serialized_graph1["nodes"]) == 10
        assert len(serialized_graph1["weights"]) == 4
        assert len(serialized_graph1["modules"]) == 0
        assert len(serialized_graph2["nodes"]) == 6
        assert len(serialized_graph2["weights"]) == 4
        assert len(serialized_graph2["modules"]) == 1
        assert serialized_graph1["weights"]["linear.weight"]["shape"] == "[4, 4]"
        assert serialized_graph1["weights"]["linear.weight"]["dtype"] == "torch.float32"
        assert serialized_graph1["weights"]["linear.weight"]["is_quantized"] is False
        assert serialized_graph1["nodes"][0]["shape"] == "[4]"
        assert serialized_graph1["nodes"][0]["dtype"] == "torch.float32"
        assert serialized_graph1["nodes"][0]["target"] == "a"
        assert serialized_graph1["nodes"][0]["op_code"] == "placeholder"
        assert serialized_graph1["nodes"][0]["name"] == "a"
        assert serialized_graph1["nodes"][6]["args"][0]["name"] == "add_1"
        assert serialized_graph1["nodes"][6]["args"][0]["is_node"] is True

        # Test the users of the nodes. No users of the last/output node.
        assert serialized_graph2["nodes"][0]["users"][0]["name"] == "submod_0"
        assert serialized_graph2["nodes"][1]["users"][0]["name"] == "submod_0"
        assert serialized_graph2["nodes"][4]["users"][0]["name"] == "output"
        assert serialized_graph2["nodes"][5]["users"] == []

        # Test quantization info serialization.
        x = torch.tensor([[-1.0, 0.0], [1.0, 2.0]])
        q_tensor = torch.quantize_per_tensor(x, 1, 0, torch.qint32)
        q_tensor_channel = torch.quantize_per_channel(
            x, torch.tensor([0.1, 0.01]), torch.tensor([10, 0]), 0, torch.quint8
        )
        result, _ = graph_manipulation.serialize_tensor_quantization(
            q_tensor, weights={}, pcq_prefix="foo"
        )
        result2, per_channel_dict = graph_manipulation.serialize_tensor_quantization(
            q_tensor_channel, weights={}, pcq_prefix="bar"
        )
        assert result["qscheme"] == "torch.per_tensor_affine"
        assert result["q_scale"] == 1.0
        assert result2["qscheme"] == "torch.per_channel_affine"
        assert result2["q_per_channel_scales"] == "bar_per_channel_scales"
        assert per_channel_dict["bar_per_channel_zero_points"]["shape"] == "[2]"

    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        catch_runtime_error = False
        try:
            ret = partitioner.partition_graph(traced, m, partitioner_config)
        except RuntimeError:
            catch_runtime_error = True
        assert catch_runtime_error

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
        assert partition.used_mem_bytes == 112
        # Select add_2 node to remove
        selected_node = None
        for node in partition.nodes:
            if node.name == "add_2":
                selected_node = node
        partition.remove_node(selected_node)
        assert partition.used_mem_bytes == 80

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self):
                super(MyRecommendationModule, self).__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given an FX module, generate a latency for each node
            based on the size of each node.
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
        assert critical_path_latency_sec == 208.0
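
    # NOTE (on test_partition_latency above): the latency model is synthetic.
    # Each NodeLatency carries a memory-bound and a compute-bound estimate, and
    # get_latency_of_partitioned_graph combines per-partition latencies with
    # inter-partition transfer cost (roughly, bytes moved divided by
    # transfer_rate_bytes_per_sec) to find the critical path. The expected
    # constants (e.g. 208.0) follow from the node sizes that
    # get_size_of_all_nodes computes for this toy module.
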
    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
            Device("dev_3", 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec,
        )
        assert critical_path_latency_sec == 160.0

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device("dev_0", 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions:
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 is replicated to dev_4, and
        # partition 1 is replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # The bn must not be folded into the conv, since it does not track
        # running stats.
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        mt = MetaTracer()

        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            graph = mt.trace(mttm, meta_args={'x': x.to(device='meta')})
            gm = torch.fx.GraphModule(mttm, graph)
            torch.testing.assert_close(gm(x), mttm(x))
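
    # NOTE (on test_meta_tracer above): MetaTracer propagates real shapes while
    # tracing. `meta_args` supplies meta-device stand-ins for the inputs, so a
    # shape-dependent branch such as `lol.shape[0] < 30` is evaluated against
    # concrete shapes without materializing data. Each trace bakes in the branch
    # taken for its shapes, which is why the loop re-traces per batch size
    # instead of reusing one graph.
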
    def test_proxy_tensor(self):
        def f(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
        inp = torch.randn(3, requires_grad=True)
        torch.testing.assert_close(traced_graph(inp), f(inp))
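
    # NOTE (on test_proxy_tensor above): make_fx records the tensor ops that f
    # actually executes -- including the ops torch.autograd.grad runs for the
    # backward pass -- so the traced graph can recompute the gradient by simply
    # replaying those ops. A minimal sketch of the call pattern (`g` is a
    # hypothetical function, not part of the suite):
    #
    #     def g(x):
    #         return x.sin()
    #
    #     g_traced = make_fx(g)(torch.randn(3))  # FX graph of the ops run by g
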
    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
        """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace the model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        # Add some random meta info to make sure it is kept around.
        for node in my_module_traced.graph.nodes:
            if node.op != "output":
                node.meta["test_meta_info"] = True

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split the module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info was still on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        stored_op = submod_node.meta.get("test_meta_info")
                        self.assertTrue(stored_op is not None and stored_op)

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)
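
    # NOTE (on test_subgraph_creation above): split_module is driven entirely
    # by the callback. Every node is mapped to an integer partition id, and
    # each distinct id becomes one `submod_<i>` child GraphModule called from
    # the parent graph. A minimal sketch, assuming an already-traced
    # GraphModule `gm` with root module `root`:
    #
    #     one_partition = split_module(gm, root, lambda node: 0)
    #     # one_partition.submod_0 now holds the entire original graph
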
    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_allclose(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting ResNet into one partition works.
        # There was previously an issue that caused submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_allclose(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }

        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )
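
    # NOTE (on test_normalize_binary_operators above): NormalizeOperators
    # rewrites the different spellings of a binary op into one canonical
    # callsite. The assertFalse checks only verify that none of the original
    # `torch.*` function targets survive the transform; the exact canonical
    # target is an implementation detail of the pass. Downstream passes then
    # only need to match a single callsite form per operator.
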
    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Module`s that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = set([torch.nn.BatchNorm2d])
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]

            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
"""

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so testing
            # correctness by comparing outputs is not correct. Skip that
            # check for these
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Module`s that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = set([torch.nn.BatchNorm2d])
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test torchscript compilation since now we're emitting type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n: torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx

        # split the module into a module with submodules
        qualname_map: Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self):
                super(Test, self).__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.register_buffer("attr2", torch.randn(2))
                self.register_buffer("attr3", torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        import tempfile
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.tensor([])
                matmuls = []

                # For some reason using a list comprehension or for-loop for this
                # doesn't work.
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)
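
    # NOTE (on test_merge_matmuls above): merging is possible because matmuls
    # that share a right-hand side can be computed as one matmul over the
    # row-concatenated LHS operands; each output row depends only on the
    # corresponding input row, so splitting the product back apart recovers the
    # original results. Illustrative identity (not part of the test):
    #
    #     lhs_a, lhs_b, rhs = torch.randn(2, 5), torch.randn(3, 5), torch.randn(5, 4)
    #     merged = torch.matmul(torch.cat([lhs_a, lhs_b]), rhs)
    #     out_a, out_b = torch.split(merged, [2, 3])
    #     # out_a equals lhs_a @ rhs and out_b equals lhs_b @ rhs
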
    def test_type_matches(self):
        should_be_equal = [
            (int, type(5)),
            (numbers.Number, type(5)),
            (numbers.Number, type(5.0)),
            (int, type(torch.float)),
            (Union[int, float], type(5)),
            (Union[int, float], type(5.0)),
            (List[int], type(5)),
            (List[int], create_type_hint([int, int])),
            (List[int], create_type_hint((int, int))),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
            (
                List[torch.Tensor],
                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
            (
                List[torch.Tensor],
                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
            (Optional[List[torch.Tensor]], List[torch.Tensor]),
            (Optional[List[int]], List[int]),
        ]
        for sig_type, arg_type in should_be_equal:
            self.assertTrue(type_matches(sig_type, arg_type))

        should_fail = [
            (int, float),
            (Union[int, float], str),
            (List[torch.Tensor], List[int]),
        ]

        for sig_type, arg_type in should_fail:
            self.assertFalse(type_matches(sig_type, arg_type))

    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu(self):
        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                layers = []
                layers2 = []
                for _ in range(10):
                    layers.append(nn.Conv2d(3, 3, 1))
                    layers.append(nn.BatchNorm2d(3))
                    layers.append(nn.ReLU())

                    layers2.append(nn.Conv2d(3, 3, 1))
                    layers2.append(nn.BatchNorm2d(3))
                    layers2.append(nn.ReLU())
                self.model = nn.Sequential(*layers)
                self.model2 = nn.Sequential(*layers2)

            def forward(self, x):
                return self.model(x) + self.model2(x)

        N, C, H, W = 1, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.optimize_for_inference(model)
            torch.testing.assert_close(model(inp), optimized_model(inp))

            optimized_model2 = optimization.optimize_for_inference(
                model, pass_config={"remove_dropout": False}
            )
            torch.testing.assert_close(model(inp), optimized_model2(inp))

    @skipIfNoTorchVision
    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu_torchvision(self):
        models = [
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        ]
        with torch.no_grad():
            for model_type in models:
                model = model_type()
                C, H, W = 3, 224, 224
                inp = torch.randn(3, C, H, W)
                model(inp)
                model.eval()
                inp = torch.randn(1, C, H, W)
                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
                optimized_model = optimization.optimize_for_inference(model)

                orig_out = model(inp)
                new_out = optimized_model(inp)
                torch.testing.assert_close(orig_out, new_out)


class TestNormalizeOperators(JitTestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        # These ops currently don't trace in FX for various reasons (e.g. they take a list of tensors)
        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot"}
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                t = _torchscript_type_to_python_type(inferred_arg_type.type())
                return t

            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the op_skip list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
            param_names = []
            param_values = []
            fx_args = []
            for idx, v in enumerate(arg_values):
                if isinstance(v, torch.Tensor):
                    param_names.append(f"arg_{idx}")
                    param_values.append(v)
                    fx_args.append(param_names[-1])
                else:
                    fx_args.append(f"{repr(v)}")

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    param_names.append(k)
                    param_values.append(v)
                    fx_args.append(f"{k} = {k}")
                else:
                    fx_args.append(f"{k} = {repr(v)}")

            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
"""

            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertTrue(norm_args_and_kwargs is not None)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, tuple())
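
        # With normalize_to_only_use_kwargs=True, normalize_function resolves
        # every positional argument against the op's schema and returns a purely
        # keyword-based callsite: args comes back empty, and the kwargs include
        # schema defaults (e.g. scale_grad_by_freq) that the caller never passed.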


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()