mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Fixes https://github.com/pytorch/pytorch/issues/155220 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160093 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			2130 lines
		
	
	
		
			78 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			2130 lines
		
	
	
		
			78 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Owner(s): ["module: fx"]
 | |
| # ruff: noqa: F841
 | |
| 
 | |
| import functools
 | |
| import math
 | |
| import numbers
 | |
| import operator
 | |
| import pickle
 | |
| import sys
 | |
| import sympy
 | |
| import tempfile
 | |
| import typing
 | |
| import unittest
 | |
| from types import BuiltinFunctionType
 | |
| from typing import Callable, NamedTuple, Optional, Union
 | |
| 
 | |
| import torch
 | |
| import torch.fx.experimental.meta_tracer
 | |
| import torch.fx.experimental.optimization as optimization
 | |
| from torch.fx._symbolic_trace import symbolic_trace
 | |
| from torch.fx.experimental import merge_matmul
 | |
| from torch.fx.experimental.accelerator_partitioner import Partitioner
 | |
| from torch.fx.experimental.proxy_tensor import make_fx
 | |
| from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
 | |
| from torch.fx.experimental.partitioner_utils import (
 | |
|     Device,
 | |
|     get_latency_of_partitioned_graph,
 | |
|     get_partition_to_latency_mapping,
 | |
|     NodeLatency,
 | |
|     PartitionerConfig,
 | |
|     PartitionMode,
 | |
| )
 | |
| from torch.fx.experimental.rewriter import RewritingTracer
 | |
| from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
 | |
| from torch.fx.graph_module import GraphModule
 | |
| from torch.fx.node import Node
 | |
| from torch.fx.operator_schemas import (
 | |
|     _torchscript_type_to_python_type,
 | |
|     create_type_hint,
 | |
|     normalize_function,
 | |
|     normalize_module,
 | |
|     type_matches,
 | |
| )
 | |
| from torch.fx.passes import graph_manipulation
 | |
| from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
 | |
| from torch.fx.passes.shape_prop import ShapeProp
 | |
| from torch.fx.passes.split_module import split_module
 | |
| from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
 | |
| from torch.testing._internal.common_device_type import (
 | |
|     instantiate_device_type_tests,
 | |
|     onlyCPU,
 | |
|     ops,
 | |
| )
 | |
| from torch.testing._internal.common_methods_invocations import op_db
 | |
| from torch.testing._internal.common_nn import module_tests, get_new_module_tests
 | |
| from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase, TEST_WITH_CROSSREF
 | |
| from torch.testing._internal.jit_utils import JitTestCase
 | |
| import torch.utils._pytree as pytree
 | |
| 
 | |
# Optional dependency probe: torchvision is only needed by the ResNet tests.
try:
    import torchvision.models
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorators that skip a test when its optional backend/dependency is missing.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    (not torch.backends.mkldnn.enabled) or (not torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)
 | |
| 
 | |
| 
 | |
def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
    """Trace ``root`` with ``RewritingTracer`` and wrap the graph in a GraphModule.

    When ``root`` is an ``nn.Module`` it becomes the owning module of the
    resulting GraphModule. Otherwise an empty ``nn.Module`` is used as the owner.
    The assert tests in this file expect ``torch._assert`` call_function nodes
    in the returned graph.
    """
    graph = RewritingTracer().trace(root)
    owner = root if isinstance(root, torch.nn.Module) else torch.nn.Module()
    return GraphModule(owner, graph)
 | |
| 
 | |
| 
 | |
| class TestFXExperimental(JitTestCase):
 | |
|     def test_find_single_partition(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(1)
 | |
|         b = torch.rand(1)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a, b])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [
 | |
|             Device("dev_0", 125, 0),
 | |
|             Device("dev_1", 150, 1),
 | |
|             Device("dev_2", 125, 2),
 | |
|         ]
 | |
|         partitioner_config = PartitionerConfig(devices)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(traced(a, b), module_with_submodules(a, b))
 | |
|         assert dag.nodes[0].logical_device_ids == [1]
 | |
| 
 | |
|     def test_lack_of_devices(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         b = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a, b])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
 | |
|         partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
 | |
|         catch_runtime_error = False
 | |
|         try:
 | |
|             ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         except RuntimeError:
 | |
|             catch_runtime_error = True
 | |
|         assert catch_runtime_error
 | |
| 
 | |
|     def test_large_node_error(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 linear = self.linear(a)
 | |
|                 add = linear + a
 | |
|                 return add
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [
 | |
|             Device("dev_0", 40, 0),
 | |
|             Device("dev_1", 40, 0),
 | |
|             Device("dev_2", 40, 0),
 | |
|             Device("dev_3", 40, 0),
 | |
|             Device("dev_4", 40, 0),
 | |
|         ]
 | |
|         partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
 | |
|         catch_runtime_error = False
 | |
|         try:
 | |
|             ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         except RuntimeError:
 | |
|             catch_runtime_error = True
 | |
|         assert catch_runtime_error
 | |
| 
 | |
|     def test_partition_node_manipulation(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 add_1 = a + b
 | |
|                 add_2 = add_1 + torch.rand(4)
 | |
|                 add_3 = add_2 + torch.rand(4)
 | |
|                 return add_3
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a, b = torch.rand(4), torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a, b])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [Device("dev_0", 1000, 0)]
 | |
|         partitioner_config = PartitionerConfig(devices)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         partition = partitioner.partitions[0]
 | |
|         assert partition.used_mem_bytes == 112
 | |
|         # Select add_2 node to remove
 | |
|         selected_node = None
 | |
|         for node in partition.nodes:
 | |
|             if node.name == "add_2":
 | |
|                 selected_node = node
 | |
|         partition.remove_node(selected_node)
 | |
|         assert partition.used_mem_bytes == 80
 | |
| 
 | |
|     def test_size_based_partition(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
|                 self.c = torch.rand(4)
 | |
| 
 | |
|             def forward(self, a, b):
 | |
|                 add_1 = a + b
 | |
|                 linear = self.linear(add_1)
 | |
|                 add_2 = linear + self.c
 | |
|                 return add_2
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         b = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a, b])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [
 | |
|             Device("dev_0", 125, 0),
 | |
|             Device("dev_1", 125, 1),
 | |
|             Device("dev_2", 125, 2),
 | |
|         ]
 | |
|         partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(traced(a, b), module_with_submodules(a, b))
 | |
|         for i, node in enumerate(dag.nodes):
 | |
|             assert node.logical_device_ids == [i]
 | |
| 
 | |
|     def test_partition_device_mapping(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 b = torch.rand(4)
 | |
|                 add_1 = a + b
 | |
|                 linear_1 = self.linear(add_1)
 | |
|                 add_2 = torch.rand(4) + a
 | |
|                 add_3 = add_2 + linear_1
 | |
|                 return add_3
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         partitioner = Partitioner()
 | |
|         devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
 | |
|         partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(traced(a), module_with_submodules(a))
 | |
|         for i, node in enumerate(dag.nodes):
 | |
|             if i == 1:
 | |
|                 assert node.logical_device_ids == [1]
 | |
|             else:
 | |
|                 assert node.logical_device_ids == [0]
 | |
| 
 | |
|     def test_sparse_nn_partition(self):
 | |
|         class MyRecommendationModule(torch.nn.Module):
 | |
|             def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
 | |
|                 layers = torch.nn.ModuleList()
 | |
|                 for _ in range(num_of_layers):
 | |
|                     ll = torch.nn.Linear(input_size, output_size)
 | |
|                     layers.append(ll)
 | |
|                     layers.append(torch.nn.ReLU())
 | |
|                 return layers
 | |
| 
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 layers = self.create_mlp(4, 4, 4)
 | |
|                 self.bottom_layers = torch.nn.Sequential(*layers)
 | |
|                 layers = self.create_mlp(3, 24, 24)
 | |
|                 self.top_layers = torch.nn.Sequential(*layers)
 | |
|                 self.embedding_layers = torch.nn.ModuleList()
 | |
|                 el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
 | |
|                 self.embedding_layers.append(el)
 | |
|                 for i in range(3):
 | |
|                     el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
 | |
|                     self.embedding_layers.append(el)
 | |
|                 el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
 | |
|                 self.embedding_layers.append(el)
 | |
| 
 | |
|             def forward(self, a, b, offset):
 | |
|                 x = self.bottom_layers(a)
 | |
|                 y = []
 | |
|                 c = []
 | |
|                 for i in range(len(self.embedding_layers)):
 | |
|                     temp = torch.randint(10, (8,))
 | |
|                     c.append(temp + b)
 | |
|                 for i in range(len(self.embedding_layers)):
 | |
|                     if i % 2 == 0:
 | |
|                         y.append(self.embedding_layers[i](c[i], offset))
 | |
|                     else:
 | |
|                         y.append(
 | |
|                             self.embedding_layers[i](torch.randint(10, (8,)), offset)
 | |
|                         )
 | |
|                 z = torch.cat([x] + y, dim=1)
 | |
|                 p = self.top_layers(z)
 | |
|                 return p
 | |
| 
 | |
|         m = MyRecommendationModule()
 | |
|         a = torch.rand(2, 4)
 | |
|         b = torch.randint(10, (8,))
 | |
|         offset = torch.randint(1, (2,))
 | |
|         traced = symbolic_trace(m)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
 | |
|         devices = [
 | |
|             Device("dev_0", 33000000, 0),
 | |
|             Device("dev_1", 33000000, 1),
 | |
|             Device("dev_2", 33000000, 2),
 | |
|         ]
 | |
|         partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
 | |
|         partitioner = Partitioner()
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
 | |
|         assert len(module_with_submodules.graph.nodes) == 24
 | |
| 
 | |
|     def test_partition_latency(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 add_1 = a + torch.rand(4)
 | |
|                 add_2 = add_1 + torch.rand(4)
 | |
|                 linear_1 = self.linear(add_1)
 | |
|                 add_3 = add_2 + linear_1
 | |
|                 add_4 = add_2 + add_3
 | |
|                 return add_4
 | |
| 
 | |
|         def get_node_to_latency_mapping(fx_module: GraphModule):
 | |
|             """Given a fx module, generate node latency for each node
 | |
|             based on the size of each node
 | |
|             """
 | |
|             node_to_latency_mapping: dict[Node, NodeLatency] = {}
 | |
|             for node in fx_module.graph.nodes:
 | |
|                 if node.op not in {"output", "placeholder", "get_attr"}:
 | |
|                     if node.size_bytes.total_size == node.size_bytes.output_size:
 | |
|                         node_to_latency_mapping[node] = NodeLatency(
 | |
|                             node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
 | |
|                         )
 | |
|                     else:
 | |
|                         node_to_latency_mapping[node] = NodeLatency(
 | |
|                             node.size_bytes.total_size, node.size_bytes.output_size
 | |
|                         )
 | |
|             return node_to_latency_mapping
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         node_to_latency_mapping = get_node_to_latency_mapping(traced)
 | |
|         devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
 | |
|         partitioner = Partitioner()
 | |
|         partitioner_config = PartitionerConfig(devices)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         self.assertEqual(traced(a), module_with_submodules(a))
 | |
|         partitions = partitioner.partitions
 | |
|         partition_to_latency_mapping = get_partition_to_latency_mapping(
 | |
|             partitions, node_to_latency_mapping
 | |
|         )
 | |
|         for p in partition_to_latency_mapping:
 | |
|             if p.partition_id == 0:
 | |
|                 assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
 | |
|             else:
 | |
|                 assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
 | |
|         transfer_rate_bytes_per_sec = 2
 | |
|         critical_path_latency_sec = get_latency_of_partitioned_graph(
 | |
|             partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
 | |
|         )
 | |
|         assert critical_path_latency_sec == 208.0
 | |
| 
 | |
|     def test_cost_aware_partition(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 add_1 = a + torch.rand(4)
 | |
|                 add_2 = add_1 + torch.rand(4)
 | |
|                 linear_1 = self.linear(add_1)
 | |
|                 add_3 = add_2 + torch.rand(4)
 | |
|                 add_4 = add_2 + linear_1
 | |
|                 add_5 = add_3 + add_4
 | |
|                 return add_5
 | |
| 
 | |
|         def get_node_to_latency_mapping(fx_module: GraphModule):
 | |
|             node_to_latency_mapping: dict[Node, NodeLatency] = {}
 | |
|             for node in fx_module.graph.nodes:
 | |
|                 if node.op not in {"output", "placeholder", "get_attr"}:
 | |
|                     if node.size_bytes.total_size == node.size_bytes.output_size:
 | |
|                         node_to_latency_mapping[node] = NodeLatency(
 | |
|                             node.size_bytes.total_size, 1
 | |
|                         )
 | |
|                     else:
 | |
|                         node_to_latency_mapping[node] = NodeLatency(
 | |
|                             node.size_bytes.total_size, node.size_bytes.output_size
 | |
|                         )
 | |
|             return node_to_latency_mapping
 | |
| 
 | |
|         m = MyModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         devices = [
 | |
|             Device("dev_0", 125, 0),
 | |
|             Device("dev_1", 125, 1),
 | |
|             Device("dev_2", 125, 2),
 | |
|             Device("dev_3", 125, 3),
 | |
|         ]
 | |
|         node_to_latency_mapping = get_node_to_latency_mapping(traced)
 | |
|         partitioner_config = PartitionerConfig(
 | |
|             devices,
 | |
|             mode=PartitionMode.cost_aware,
 | |
|             transfer_rate_bytes_per_sec=2,
 | |
|             node_to_latency_mapping=node_to_latency_mapping,
 | |
|         )
 | |
|         partitioner = Partitioner()
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(traced(a), module_with_submodules(a))
 | |
|         partitions = partitioner.partitions
 | |
|         partition_to_latency_mapping = get_partition_to_latency_mapping(
 | |
|             partitions, node_to_latency_mapping
 | |
|         )
 | |
|         critical_path_latency_sec = get_latency_of_partitioned_graph(
 | |
|             partitions,
 | |
|             partition_to_latency_mapping,
 | |
|             partitioner_config.transfer_rate_bytes_per_sec,
 | |
|         )
 | |
|         assert critical_path_latency_sec == 160.0
 | |
| 
 | |
|     def test_aot_based_partition(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.b = torch.rand(4)
 | |
|                 self.c = torch.rand(4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 add_1 = a + self.b
 | |
|                 add_2 = self.c + add_1
 | |
|                 return add_2
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         node_to_partition_id = {}
 | |
|         partition_to_logical_devices = {}
 | |
|         count = 0
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op not in {"placeholder", "get_attr", "output"}:
 | |
|                 node_to_partition_id[node] = count
 | |
|                 partition_to_logical_devices[count] = [0]
 | |
|                 count += 1
 | |
|         devices = [Device("dev_0", 200, 0)]
 | |
|         partitioner_config = PartitionerConfig(
 | |
|             devices=devices,
 | |
|             mode=PartitionMode.aot_based,
 | |
|             node_to_partition_mapping=node_to_partition_id,
 | |
|             partition_to_logical_device_mapping=partition_to_logical_devices,
 | |
|         )
 | |
|         partitioner = Partitioner()
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         dag = ret.dag
 | |
|         self.assertEqual(module_with_submodules(a), traced(a))
 | |
|         for node in dag.nodes:
 | |
|             assert node.size_bytes == 48
 | |
|             assert node.logical_device_ids == [0]
 | |
| 
 | |
|     def test_replace_target_nodes_with(self):
 | |
|         class testModule(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 return a + b
 | |
| 
 | |
|         m = testModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         input1 = torch.randn(1)
 | |
|         input2 = torch.randn(1)
 | |
|         assert (input1 + input2) == traced(input1, input2)
 | |
|         graph_manipulation.replace_target_nodes_with(
 | |
|             fx_module=traced,
 | |
|             old_op="call_function",
 | |
|             old_target=operator.add,
 | |
|             new_op="call_function",
 | |
|             new_target=operator.mul,
 | |
|         )
 | |
|         assert (input1 * input2) == traced(input1, input2)
 | |
| 
 | |
|     def test_saturate_host(self):
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 add_1 = a + torch.rand(4)
 | |
|                 add_2 = add_1 + torch.rand(4)
 | |
|                 linear_1 = self.linear(add_1)
 | |
|                 add_3 = add_2 + linear_1
 | |
|                 add_4 = add_2 + add_3
 | |
|                 return add_4
 | |
| 
 | |
|         m = TestModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(4)
 | |
|         graph_manipulation.get_size_of_all_nodes(traced, [a])
 | |
|         devices = [
 | |
|             Device("dev_0", 200, 0),
 | |
|             Device("dev_1", 200, 1),
 | |
|             Device("dev_2", 100, 2),
 | |
|             Device("dev_3", 100, 3),
 | |
|             Device("dev_4", 200, 4),
 | |
|             Device("dev_5", 100, 5),
 | |
|         ]
 | |
|         partitioner = Partitioner()
 | |
|         # Without host saturation, the model will be split into two partitions.
 | |
|         # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
 | |
|         partitioner_config = PartitionerConfig(devices, saturate_host=True)
 | |
|         ret = partitioner.partition_graph(traced, m, partitioner_config)
 | |
|         module_with_submodules = ret.module_with_submodules
 | |
|         self.assertEqual(traced(a), module_with_submodules(a))
 | |
| 
 | |
|         partitions = partitioner.partitions
 | |
|         self.assertEqual(len(partitions), 2)
 | |
|         # With host saturation, partition 1 will be replicated to dev_4, and partition 2
 | |
|         # will be replicated to dev_2.
 | |
|         self.assertEqual(partitions[0].logical_device_ids, [0, 4])
 | |
|         self.assertEqual(partitions[1].logical_device_ids, [1, 2])
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_conv_bn_fusion(self):
 | |
|         rn18 = resnet18().eval()
 | |
|         traced = symbolic_trace(rn18)
 | |
|         fused = optimization.fuse(traced)
 | |
| 
 | |
|         self.assertTrue(
 | |
|             all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
 | |
|         )
 | |
| 
 | |
|         N, C, H, W = 20, 3, 224, 224
 | |
|         inp = torch.randn(N, C, H, W)
 | |
| 
 | |
|         self.assertEqual(fused(inp), rn18(inp))
 | |
| 
 | |
|     def test_conv_bn_fusion_not_running_state(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
 | |
|                 self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = self.conv(x)
 | |
|                 x = self.bn(x)
 | |
|                 return x
 | |
| 
 | |
|         model = M().eval()
 | |
| 
 | |
|         traced = symbolic_trace(model)
 | |
|         fused = optimization.fuse(traced)
 | |
|         inp = torch.randn([1, 32, 50, 50])
 | |
| 
 | |
|         # bn need not be folded in conv
 | |
|         self.assertTrue(
 | |
|             any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
 | |
|         )
 | |
|         self.assertEqual(fused(inp), model(inp))
 | |
| 
 | |
|     def test_conv_bn_fusion_mixed_dtype(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
 | |
|                 self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = self.conv(x)
 | |
|                 x = self.bn(x)
 | |
|                 return x
 | |
| 
 | |
|         model = M().eval()
 | |
| 
 | |
|         traced = symbolic_trace(model)
 | |
|         fused = optimization.fuse(traced)
 | |
|         inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)
 | |
| 
 | |
|         self.assertTrue(
 | |
|             all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
 | |
|         )
 | |
|         self.assertEqual(fused(inp), model(inp))
 | |
| 
 | |
|     def test_call_to_assert_no_msg(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 assert a == b
 | |
|                 return a + b
 | |
| 
 | |
|         m = M()
 | |
|         traced = symbolic_trace_with_rewrite(m)
 | |
| 
 | |
|         # Make sure the graph is well-formed
 | |
|         traced.graph.lint()
 | |
| 
 | |
|         # Check the IR to make sure there's a call_function node with target == "Assert"
 | |
|         self.assertTrue(
 | |
|             any(
 | |
|                 node.op == "call_function" and node.target == torch._assert
 | |
|                 for node in traced.graph.nodes
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
 | |
|         traced(3, 3)
 | |
|         with self.assertRaisesRegex(AssertionError, ""):
 | |
|             traced(3, 5)
 | |
| 
 | |
|         # Confirm that the output is correct
 | |
|         self.assertEqual(traced(3, 3), m(3, 3))
 | |
| 
 | |
|     def test_meta_tracer(self):
 | |
|         class MetaTracerTestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
 | |
|                 self.layernorm = torch.nn.LayerNorm(16)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 emb = self.emb(x)
 | |
|                 emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
 | |
|                 lol = self.layernorm(emb)
 | |
|                 return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)
 | |
| 
 | |
|         mttm = MetaTracerTestModule()
 | |
|         for BS in [15, 35]:
 | |
|             x = torch.zeros(BS, dtype=torch.long).random_(42)
 | |
|             meta_args = {'x' : x.to(device='meta')}
 | |
|             gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
 | |
|             torch.testing.assert_close(gm(x), mttm(x))
 | |
| 
 | |
|             # Test serialization/deserialization
 | |
|             with tempfile.TemporaryDirectory() as tmp_dir:
 | |
|                 with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
 | |
|                     pickle.dump(gm, f)
 | |
| 
 | |
|                 with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
 | |
|                     loaded = pickle.load(f)
 | |
| 
 | |
|                 torch.testing.assert_close(loaded(x), mttm(x))
 | |
| 
 | |
| 
 | |
|     def test_call_to_assert_with_msg(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 assert a == b, "test message"
 | |
|                 return a + b
 | |
| 
 | |
|         m = M()
 | |
|         traced = symbolic_trace_with_rewrite(m)
 | |
| 
 | |
|         # Make sure the graph is well-formed
 | |
|         traced.graph.lint()
 | |
| 
 | |
|         # Check the IR to make sure there's a call_function node with target == "Assert"
 | |
|         self.assertTrue(
 | |
|             any(
 | |
|                 node.op == "call_function" and node.target == torch._assert
 | |
|                 for node in traced.graph.nodes
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
 | |
|         traced(3, 3)
 | |
|         with self.assertRaisesRegex(AssertionError, "test message"):
 | |
|             traced(3, 5)
 | |
| 
 | |
|         # Confirm that the output is correct
 | |
|         self.assertEqual(traced(3, 3), m(3, 3))
 | |
| 
 | |
|     def test_call_to_assert_with_empty_msg(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 assert a == b, ""
 | |
|                 return a + b
 | |
| 
 | |
|         m = M()
 | |
|         traced = symbolic_trace_with_rewrite(m)
 | |
| 
 | |
|         # Make sure the graph is well-formed
 | |
|         traced.graph.lint()
 | |
| 
 | |
|         # Check the IR to make sure there's a call_function node with target == "Assert"
 | |
|         self.assertTrue(
 | |
|             any(
 | |
|                 node.op == "call_function" and node.target == torch._assert
 | |
|                 for node in traced.graph.nodes
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
 | |
|         traced(3, 3)
 | |
|         with self.assertRaisesRegex(AssertionError, ""):
 | |
|             traced(3, 5)
 | |
| 
 | |
|         # Confirm that the output is correct
 | |
|         self.assertEqual(traced(3, 3), m(3, 3))
 | |
| 
 | |
|     def test_call_to_assert_with_multiline_message(self):
 | |
|         class M(torch.nn.Module):
 | |
|             def forward(self, a, b):
 | |
|                 error_msg = """
 | |
| An error message with
 | |
| terrible spacing
 | |
|                 """
 | |
|                 assert a == b, error_msg
 | |
|                 return a + b
 | |
| 
 | |
|         m = M()
 | |
|         traced = symbolic_trace_with_rewrite(m)
 | |
| 
 | |
|         # Make sure the graph is well-formed
 | |
|         traced.graph.lint()
 | |
| 
 | |
|         # Check the IR to make sure there's a call_function node with target == "Assert"
 | |
|         self.assertTrue(
 | |
|             any(
 | |
|                 node.op == "call_function" and node.target == torch._assert
 | |
|                 for node in traced.graph.nodes
 | |
|             )
 | |
|         )
 | |
| 
 | |
|         # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
 | |
|         error_msg = """
 | |
| An error message with
 | |
| terrible spacing
 | |
|     """
 | |
|         traced(3, 3)
 | |
|         with self.assertRaisesRegex(AssertionError, error_msg):
 | |
|             traced(3, 5)
 | |
| 
 | |
|         # Confirm that the output is correct
 | |
|         self.assertEqual(traced(3, 3), m(3, 3))
 | |
| 
 | |
|     def test_subgraph_creation(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.param = torch.nn.Parameter(torch.rand(3, 4))
 | |
|                 self.linear = torch.nn.Linear(4, 5)
 | |
| 
 | |
|             def forward(self, x, y):
 | |
|                 z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
 | |
|                 w = self.linear(y).clamp(min=0.0, max=1.0)
 | |
|                 return z + w
 | |
| 
 | |
|         # symbolically trace model
 | |
|         my_module = MyModule()
 | |
|         my_module_traced = symbolic_trace(my_module)
 | |
| 
 | |
|         # random mod partitioning
 | |
|         partition_counter = 0
 | |
|         NPARTITIONS = 3
 | |
| 
 | |
|         # Add some random meta info to make sure it is kept around.
 | |
|         for node in my_module_traced.graph.nodes:
 | |
|             if node.op != "output":
 | |
|                 node.meta["test_meta_info"] = True
 | |
| 
 | |
|         def mod_partition(node: Node):
 | |
|             nonlocal partition_counter
 | |
|             partition = partition_counter % NPARTITIONS
 | |
|             partition_counter = (partition_counter + 1) % NPARTITIONS
 | |
|             return partition
 | |
| 
 | |
|         # split module in module with submodules
 | |
|         module_with_submodules = split_module(
 | |
|             my_module_traced, my_module, mod_partition
 | |
|         )
 | |
| 
 | |
|         # Check that test_meta_info was still on all nodes.
 | |
|         submodules = dict(module_with_submodules.named_modules())
 | |
|         for node in module_with_submodules.graph.nodes:
 | |
|             if node.op == "call_module":
 | |
|                 submod = submodules[node.target]
 | |
|                 self.assertTrue(isinstance(submod, torch.fx.GraphModule))
 | |
|                 for submod_node in submod.graph.nodes:
 | |
|                     if submod_node.op != "output":
 | |
|                         stored_op = submod_node.meta.get("test_meta_info")
 | |
|                         self.assertTrue(stored_op is not None and stored_op)
 | |
| 
 | |
|         x = torch.rand(3, 4)
 | |
|         y = torch.rand(3, 4)
 | |
| 
 | |
|         orig_out = my_module_traced(x, y)
 | |
|         submodules_out = module_with_submodules(x, y)
 | |
| 
 | |
|         self.assertEqual(orig_out, submodules_out)
 | |
| 
 | |
|     def test_split_module_input_names(self):
 | |
|         class Mod(torch.nn.Module):
 | |
|             def forward(self, x, a0, a1, b0, b1, c0, c1):
 | |
|                 x = x + (a0 ** 2) + (a1 / 2)
 | |
|                 x = x + (b0 ** 2) + (b1 / 2)
 | |
|                 x = x + (c0 ** 2) + (c1 / 2)
 | |
|                 return x
 | |
| 
 | |
|         mod = Mod()
 | |
|         traced = torch.fx.symbolic_trace(mod)
 | |
| 
 | |
|         seen = 0
 | |
| 
 | |
|         def split(n):
 | |
|             nonlocal seen
 | |
|             result = seen // 4
 | |
|             seen += 1
 | |
|             return result
 | |
| 
 | |
|         split = split_module(traced, mod, split, keep_original_input_name=False)
 | |
| 
 | |
|         # All the submodules should take in the inputs in the same order.
 | |
|         args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
 | |
|         output0 = split.submod_0(*args)
 | |
|         output1 = split.submod_1(*args)
 | |
|         output2 = split.submod_2(*args)
 | |
|         self.assertEqual(output0, output1)
 | |
|         self.assertEqual(output1, output2)
 | |
| 
 | |
|         # Each submodule should have normalized input names
 | |
|         def check_ph(gm):
 | |
|             nodes = list(gm.graph.nodes)
 | |
|             self.assertEqual(nodes[0].target, "arg_0")
 | |
|             self.assertEqual(nodes[1].target, "arg_1")
 | |
|             self.assertEqual(nodes[2].target, "arg_2")
 | |
| 
 | |
|         check_ph(split.submod_0)
 | |
|         check_ph(split.submod_1)
 | |
|         check_ph(split.submod_2)
 | |
| 
 | |
|     def test_split_module_dead_code(self):
 | |
|         class ModWithDeadCode(torch.nn.Module):
 | |
|             def forward(self, x):
 | |
|                 output = x * 2  # we want this
 | |
|                 dead_line = x + 2  # this is dead
 | |
|                 return output
 | |
| 
 | |
|         mod = ModWithDeadCode()
 | |
|         traced = torch.fx.symbolic_trace(mod)
 | |
| 
 | |
|         # split into before (0), target (1), and after(2)
 | |
|         saw_mul = False
 | |
| 
 | |
|         def split_callback(n):
 | |
|             nonlocal saw_mul
 | |
|             if n.target == operator.mul:
 | |
|                 saw_mul = True
 | |
|                 return 1
 | |
| 
 | |
|             if not saw_mul:
 | |
|                 return 0
 | |
|             if saw_mul:
 | |
|                 return 2
 | |
| 
 | |
|         split = split_module(traced, mod, split_callback)
 | |
| 
 | |
|         x = torch.randn((5,))
 | |
|         torch.testing.assert_close(
 | |
|             split(x), traced(x)
 | |
|         )
 | |
| 
 | |
|     def test_split_module_return_node(self):
 | |
|         def foo(x):
 | |
|             x.add_(1)
 | |
| 
 | |
|         gm = make_fx(foo, tracing_mode="fake")(torch.randn(3,))
 | |
| 
 | |
|         def cb(_):
 | |
|             return 1
 | |
| 
 | |
|         sp_gm = split_module(gm, None, cb)
 | |
|         submod_gm = sp_gm.submod_1
 | |
|         for node in submod_gm.graph.nodes:
 | |
|             if node.op == "output":
 | |
|                 break
 | |
|         else:
 | |
|             raise RuntimeError("Expected the subgraph to have an output node.")
 | |
| 
 | |
| 
 | |
|     def test_split_module_kwargs_expansion(self):
 | |
|         class ModuleWithKwargsExpansion(torch.nn.Module):
 | |
|             def forward(self, x, **kwargs):
 | |
|                 return x + kwargs['foo']
 | |
| 
 | |
|         mod = ModuleWithKwargsExpansion()
 | |
|         traced = torch.fx.symbolic_trace(mod)
 | |
| 
 | |
|         seen_getitem = False
 | |
| 
 | |
|         def split_callback(n):
 | |
|             nonlocal seen_getitem
 | |
|             split_idx = int(seen_getitem)
 | |
|             if n.target == operator.getitem:
 | |
|                 seen_getitem = True
 | |
|             return split_idx
 | |
| 
 | |
|         split = split_module(traced, mod, split_callback)
 | |
| 
 | |
|         x = torch.randn(5, 3)
 | |
|         foo = torch.randn(5, 3)
 | |
|         torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_subgraph_trivial_resnet(self):
 | |
|         # Smoke test trivially splitting resnet into 1 partition works
 | |
|         # There was an issue before causing submodule names to be aliased
 | |
|         m = resnet18()
 | |
|         traced = symbolic_trace(m)
 | |
|         a = torch.rand(64, 3, 7, 7)
 | |
|         module_with_submodules = split_module(traced, m, lambda node: 0)
 | |
|         module_with_submodules(a)
 | |
| 
 | |
|     def test_split_module_default_arg(self):
 | |
|         class ModelToTrace(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.lin = torch.nn.Linear(512, 512)
 | |
| 
 | |
|             def forward(self, x, targets=None):
 | |
|                 x = self.lin(x)
 | |
| 
 | |
|                 if targets is not None:
 | |
|                     x = x + targets
 | |
| 
 | |
|                 return x
 | |
| 
 | |
|         mtt = ModelToTrace()
 | |
|         traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})
 | |
| 
 | |
|         split = split_module(traced, mtt, lambda node: 0)
 | |
| 
 | |
|         x = torch.randn(50, 512)
 | |
|         torch.testing.assert_close(split(x), traced(x))
 | |
| 
 | |
|     def test_split_module_keep_original_order_and_noop_graph(self):
 | |
|         # Verify that split_module returns a similar no-op graph
 | |
|         # for `keep_original_order={True|False}`.
 | |
|         def fn(x):
 | |
|             return (x,)
 | |
| 
 | |
|         g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))
 | |
| 
 | |
|         # g.graph.print_tabular()
 | |
|         # opcode       name    target    args       kwargs
 | |
|         # -----------  ------  --------  ---------  --------
 | |
|         # placeholder  x_1     x_1       ()         {}
 | |
|         # output       output  output    ((x_1,),)  {}
 | |
| 
 | |
|         def _test_split_graph(split_gm):
 | |
|             # Verify that the split_gm has same structure as original
 | |
|             self.assertEqual(len(split_gm.graph.nodes), 2)
 | |
| 
 | |
|             nodes = list(split_gm.graph.nodes)
 | |
|             self.assertEqual(nodes[0].op, "placeholder")
 | |
|             self.assertEqual(nodes[1].op, "output")
 | |
| 
 | |
|         # `keep_original_order=False`
 | |
|         _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False))
 | |
| 
 | |
|         # `keep_original_order=True`
 | |
|         _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
 | |
| 
 | |
    @unittest.skipIf(TEST_WITH_CROSSREF, "See https://github.com/pytorch/pytorch/issues/160077")
    def test_split_module_symint_dependency_handling(self):
        """Use split_module as a torch.compile backend on a graph with data-dependent sizes.

        ``GraniteMoeMoE.forward`` converts a tensor computed from the input
        (``expert_size``) into a Python list with ``.tolist()`` and passes it to
        ``Tensor.split``, so the split sizes depend on tensor contents. The
        backend below partitions the captured graph into groups of five
        distinct nodes, and the compiled result must equal eager execution.
        NOTE(review): judging by the test name, the point is that split_module
        correctly threads symbolic sizes (SymInts) between partitions; the
        partitioned graph itself is not inspected here.
        """
        # Based on the code from - transformers/models/granitemoe/modeling_granitemoe.py
        class GraniteMoeTopKGating(torch.nn.Module):
            # Router: picks the top_k experts per token and returns routing metadata.
            def __init__(self, input_size: int, num_experts: int, top_k: int):
                super().__init__()

                self.num_experts = num_experts
                self.input_size = input_size
                self.top_k = top_k

                self.layer = torch.nn.Linear(input_size, num_experts, bias=False)

            def forward(self, hidden_states):
                # compute the top_k routing decision
                logits = self.layer(hidden_states).float()  # [batch_size x seq_len, num_experts]
                top_k_logits, top_k_indices = logits.topk(self.top_k, dim=1)  # [num_tokens, top_k]
                top_k_gates = torch.softmax(top_k_logits, dim=1).type_as(hidden_states)  # [num_tokens, top_k]

                # compute number of input given to each expert
                zeros = torch.zeros(
                    [top_k_gates.size(0), self.num_experts], dtype=top_k_gates.dtype, device=top_k_gates.device
                )  # [num_tokens, num_experts]
                gates = zeros.scatter(1, top_k_indices, 1)  # [num_tokens, num_experts]
                expert_size = gates.long().sum(0)  # [num_experts,]
                # The list values come from tensor contents, i.e. they are data-dependent.
                expert_size = expert_size.tolist()

                # sort and group input tokens according to expert assignment
                top_k_experts = top_k_indices.flatten()  # [num_tokens * top_k]
                _, index_sorted_experts = top_k_experts.sort(0)  # [num_tokens * top_k]
                batch_index = index_sorted_experts.div(self.top_k, rounding_mode="trunc")  # [num_tokens * top_k]

                # gather the gate values for grouped input tokens
                top_k_gates = top_k_gates.flatten()  # [num_tokens * top_k]
                batch_gates = top_k_gates[index_sorted_experts]  # [num_tokens * top_k]

                return index_sorted_experts, batch_index, batch_gates, expert_size, logits

        class GraniteMoeMoE(torch.nn.Module):
            # Minimal MoE layer: routes tokens, gathers them by expert, and splits
            # the gathered rows into one chunk per expert using expert_size.
            def __init__(self):
                super().__init__()

                self.input_size = 32
                self.num_local_experts = 4

                num_experts_per_tok = 2
                self.router = GraniteMoeTopKGating(
                    input_size=self.input_size,
                    num_experts=self.num_local_experts,
                    top_k=num_experts_per_tok,
                )

            def forward(self, layer_input):
                _, batch_index, _, expert_size, _ = self.router(layer_input)
                expert_inputs = layer_input[batch_index]
                # Split sizes are the data-dependent list produced by the router.
                return expert_inputs.split(expert_size, dim=0)

        moe = GraniteMoeMoE()
        inp = torch.randn([32, 32])

        # Eager reference result.
        expected = moe(inp)

        PARTITION_ID = 0
        PARTITION_OPS_CTR = 0
        NODE_PARTITION_MAP = {}

        # `callback` is called multiple times with same `node` in `split_module`.
        # Cache the result such that partition id is consistent across calls.
        def callback(node) -> int:
            # Assigns each group of five distinct nodes (the last group may be
            # smaller) to a new partition id. Ids start at 1 because the very
            # first call bumps PARTITION_ID from 0 before assigning it.
            # (NODE_PARTITION_MAP is only mutated, never rebound, so listing it
            # in `nonlocal` is not strictly required.)
            nonlocal PARTITION_ID, PARTITION_OPS_CTR, NODE_PARTITION_MAP
            if node in NODE_PARTITION_MAP:
                return NODE_PARTITION_MAP[node]

            if PARTITION_OPS_CTR % 5 == 0:
                PARTITION_ID += 1

            PARTITION_OPS_CTR += 1

            NODE_PARTITION_MAP[node] = PARTITION_ID
            return PARTITION_ID

        def backend(gm, inps):
            # torch.compile backend: takes the captured GraphModule and returns
            # the split module, which torch.compile then runs in its place.
            split_gm = split_module(gm, root_m=None, split_callback=callback,
                                    keep_original_order=True, keep_original_node_name=True)
            return split_gm

        actual = torch.compile(moe, backend=backend)(inp)
        torch.testing.assert_close(actual, expected)
 | |
| 
 | |
|     def test_normalize_binary_operators(self):
 | |
|         ops_to_test = {
 | |
|             torch.add,
 | |
|             torch.mul,
 | |
|             torch.sub,
 | |
|             torch.div,
 | |
|             torch.floor_divide,
 | |
|             torch.remainder,
 | |
|             torch.eq,
 | |
|             torch.ne,
 | |
|             torch.lt,
 | |
|             torch.le,
 | |
|             torch.gt,
 | |
|             torch.ge,
 | |
|         }
 | |
| 
 | |
|         # Test Tensor/Tensor callsite
 | |
|         for op in ops_to_test:
 | |
| 
 | |
|             class WrapperMod(torch.nn.Module):
 | |
|                 def forward(self, x, y):
 | |
|                     return op(x, y)
 | |
| 
 | |
|             traced = symbolic_trace(WrapperMod())
 | |
|             normalized = NormalizeOperators(traced).transform()
 | |
|             x, y = torch.randn(3, 4), torch.randn(3, 4)
 | |
|             torch.testing.assert_close(traced(x, y), normalized(x, y))
 | |
|             self.assertFalse(
 | |
|                 any(n.target in ops_to_test for n in normalized.graph.nodes)
 | |
|             )
 | |
| 
 | |
|         # Test Tensor/scalar callsite
 | |
|         for op in ops_to_test:
 | |
| 
 | |
|             class WrapperMod(torch.nn.Module):
 | |
|                 def forward(self, x):
 | |
|                     return op(x, 42)
 | |
| 
 | |
|             traced = symbolic_trace(WrapperMod())
 | |
|             normalized = NormalizeOperators(traced).transform()
 | |
|             x = torch.randn(3, 4)
 | |
|             torch.testing.assert_close(traced(x), normalized(x))
 | |
|             self.assertFalse(
 | |
|                 any(n.target in ops_to_test for n in normalized.graph.nodes)
 | |
|             )
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_normalize_args(self):
 | |
|         m = resnet18()
 | |
| 
 | |
|         class FunctionalTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(
 | |
|                 self, m: torch.nn.Module, module_qualified_name: str
 | |
|             ) -> bool:
 | |
|                 # `leaves` contains the set of standard `nn.Modules` that are not
 | |
|                 # currently symbolically traceable. Ideally this set would be empty
 | |
|                 leaves = {torch.nn.BatchNorm2d}
 | |
|                 return type(m) in leaves
 | |
| 
 | |
|         traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
 | |
| 
 | |
|         input = torch.randn(5, 3, 224, 224)
 | |
|         ref_outs = traced(input)
 | |
| 
 | |
|         ShapeProp(traced).propagate(input)
 | |
|         traced = NormalizeArgs(traced).transform()
 | |
| 
 | |
|         modules = dict(traced.named_modules())
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "call_function" and node.target != operator.add:
 | |
|                 self.assertEqual(len(node.args), 0)
 | |
|             elif node.op == "call_module":
 | |
|                 submod_class = modules[node.target].__class__
 | |
|                 nn_class = getattr(torch.nn, submod_class.__name__)
 | |
|                 if submod_class == nn_class:
 | |
|                     self.assertEqual(len(node.args), 0)
 | |
|         traced(input)
 | |
|         self.assertEqual(traced(input), ref_outs)
 | |
| 
 | |
|     def test_normalize_modules_exhaustive(self):
 | |
|         """
 | |
|         Exhaustively test `Node.normalized_arguments` on all standard
 | |
|         torch.nn Module classes
 | |
|         """
 | |
|         for test_params in module_tests + get_new_module_tests():
 | |
|             if "constructor" not in test_params:
 | |
|                 constructor = getattr(torch.nn, test_params["module_name"])
 | |
|             else:
 | |
|                 constructor = test_params["constructor"]
 | |
| 
 | |
|             if "constructor_args" not in test_params:
 | |
|                 args = ()
 | |
|             else:
 | |
|                 args = test_params["constructor_args"]
 | |
| 
 | |
|             mod = constructor(*args)
 | |
|             # Skip modules that are not standard `torch.nn`
 | |
|             # instances, including functionals. (functionals
 | |
|             # are tested in test_normalize_args)
 | |
|             if mod.__class__.__name__ not in dir(torch.nn):
 | |
|                 continue
 | |
| 
 | |
|             if "input_fn" not in test_params:
 | |
|                 inputs = torch.randn(test_params["input_size"])
 | |
|             else:
 | |
|                 inputs = test_params["input_fn"]()
 | |
| 
 | |
|             if not isinstance(inputs, (tuple, list)):
 | |
|                 inputs = (inputs,)
 | |
| 
 | |
|             params = ", ".join(f"v{i}" for i in range(len(inputs)))
 | |
| 
 | |
|             # Generate a class to wrap this standard `nn.Module` instance
 | |
|             test_classname = f"Test{mod.__class__.__name__}"
 | |
|             test_mod_code = f"""
 | |
| class {test_classname}(torch.nn.Module):
 | |
|     def __init__(self, mod):
 | |
|         super().__init__()
 | |
|         self.mod = mod
 | |
| 
 | |
|     def forward(self, {params}):
 | |
|         return self.mod({params})
 | |
|             """
 | |
| 
 | |
|             gbls = {"torch": torch}
 | |
|             exec(test_mod_code, gbls)
 | |
| 
 | |
|             test_instance = gbls[test_classname](mod)
 | |
|             traced = symbolic_trace(test_instance)
 | |
| 
 | |
|             # Use `Node.normalized_arguments` to get a new set of arguments
 | |
|             # to feed to the Module. Then, rewrite the node to only take
 | |
|             # in those arguments as kwargs
 | |
|             modules = dict(traced.named_modules())
 | |
|             for node in traced.graph.nodes:
 | |
|                 if node.op == "call_module":
 | |
|                     submod_class = modules[node.target].__class__
 | |
|                     nn_class = getattr(torch.nn, submod_class.__name__)
 | |
|                     if submod_class == nn_class:
 | |
|                         normalized_args = node.normalized_arguments(traced)
 | |
|                         normalized_args2 = normalize_module(
 | |
|                             traced, node.target, node.args, node.kwargs
 | |
|                         )
 | |
|                         assert normalized_args == normalized_args2
 | |
|                         assert normalized_args
 | |
|                         node.args = normalized_args.args
 | |
|                         node.kwargs = normalized_args.kwargs
 | |
| 
 | |
|             traced.recompile()
 | |
| 
 | |
|             # These Modules have an RNG in their forward, so testing
 | |
|             # correctness by comparing outputs is not correct. Skip that
 | |
|             # check for these
 | |
|             stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}
 | |
| 
 | |
|             if mod.__class__.__name__ not in stochastic_modules:
 | |
|                 self.assertEqual(traced(*inputs), mod(*inputs))
 | |
| 
 | |
|             traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
 | |
|             modules = dict(traced.named_modules())
 | |
|             for node in traced.graph.nodes:
 | |
|                 if node.op == "call_module":
 | |
|                     submod_class = modules[node.target].__class__
 | |
|                     nn_class = getattr(torch.nn, submod_class.__name__)
 | |
|                     if submod_class == nn_class:
 | |
|                         self.assertEqual(len(node.args), 0)
 | |
| 
 | |
|     def test_normalize_args_preserve_meta(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def forward(self, a):
 | |
|                 return torch.add(a, 3)
 | |
| 
 | |
|         m = MyModule()
 | |
|         traced = symbolic_trace(m)
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "call_function" and node.target == torch.add:
 | |
|                 node.meta["my_key"] = 7
 | |
|                 break
 | |
|         else:
 | |
|             self.fail("Didn't find call_function torch.add")
 | |
| 
 | |
|         input = torch.randn(2, 3)
 | |
|         ShapeProp(traced).propagate(input)
 | |
|         traced = NormalizeArgs(traced).transform()
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "call_function" and node.target == torch.add:
 | |
|                 self.assertTrue("my_key" in node.meta)
 | |
|                 self.assertEqual(node.meta["my_key"], 7)
 | |
|                 break
 | |
|         else:
 | |
|             self.fail("Didn't find call_function torch.add")
 | |
| 
 | |
|     def test_normalize_args_perserve_type(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def forward(self, a: list[torch.Tensor]):
 | |
|                 return torch.add(a[0], a[1])
 | |
| 
 | |
|         m = MyModule()
 | |
|         traced = symbolic_trace(m)
 | |
|         traced = NormalizeArgs(traced).transform()
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "placeholder":
 | |
|                 self.assertEqual(node.type, list[torch.Tensor])
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     def test_annotate_returns_with_schema(self):
 | |
|         m = resnet18()
 | |
| 
 | |
|         traced_modules = symbolic_trace(m)
 | |
|         traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
 | |
|         for node in traced_modules_annotated.graph.nodes:
 | |
|             if node.type is None:
 | |
|                 check = (node.op, node.target)
 | |
|                 self.assertIn(
 | |
|                     check,
 | |
|                     {
 | |
|                         ("placeholder", "x"),
 | |
|                         ("call_module", "maxpool"),
 | |
|                         ("call_function", operator.add),
 | |
|                         ("call_function", torch.flatten),
 | |
|                         ("output", "output"),
 | |
|                     }
 | |
|                 )
 | |
| 
 | |
|         # Smoke test torchscript compilation since now we're emitting type annotations
 | |
|         torch.jit.script(traced_modules_annotated)
 | |
| 
 | |
|         class FunctionalTracer(torch.fx.Tracer):
 | |
|             def is_leaf_module(
 | |
|                 self, m: torch.nn.Module, module_qualified_name: str
 | |
|             ) -> bool:
 | |
|                 # `leaves` contains the set of standard `nn.Modules` that are not
 | |
|                 # currently symbolically traceable. Ideally this set would be empty
 | |
|                 leaves = {torch.nn.BatchNorm2d}
 | |
|                 return type(m) in leaves
 | |
| 
 | |
|         traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))
 | |
| 
 | |
|         traced_functionals_annotated = AnnotateTypesWithSchema(
 | |
|             traced_functionals
 | |
|         ).transform()
 | |
|         for node in traced_functionals_annotated.graph.nodes:
 | |
|             if node.type is None:
 | |
|                 check = (node.op, node.target)
 | |
|                 excluded_nodes = {
 | |
|                     ("placeholder", "x"),
 | |
|                     # Return type differs based on boolean dispatch :(
 | |
|                     ("call_function", torch.nn.functional.max_pool2d),
 | |
|                     ("output", "output"),
 | |
|                 }
 | |
|                 # AnnotateTypesWithSchema doesn't work with bound C++ functions
 | |
|                 if not isinstance(node.target, BuiltinFunctionType):
 | |
|                     self.assertIn(check, excluded_nodes)
 | |
| 
 | |
|         # Smoke test torchscript compilation since now we're emitting type annotations
 | |
|         torch.jit.script(traced_functionals_annotated)
 | |
| 
 | |
|     def test_annotate_getitem_node(self):
 | |
|         class CustomType:
 | |
|             pass
 | |
| 
 | |
|         class CustomNamedTuple(NamedTuple):
 | |
|             x: int
 | |
|             y: float
 | |
| 
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
 | |
|                 inp_0 = inp[0]
 | |
|                 inp_1 = inp[1]
 | |
|                 inp2_0 = inp2[0]
 | |
|                 inp3_x = inp3.x
 | |
|                 inp3_y = inp3.y
 | |
|                 return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
 | |
| 
 | |
|         class MyModule2(torch.nn.Module):
 | |
|             def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple):
 | |
|                 inp_0 = inp[0]
 | |
|                 inp_1 = inp[1]
 | |
|                 inp2_0 = inp2[0]
 | |
|                 inp3_x = inp3.x
 | |
|                 inp3_y = inp3.y
 | |
|                 return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y
 | |
| 
 | |
|         my_module = MyModule()
 | |
|         my_module_traced = torch.fx.symbolic_trace(my_module)
 | |
| 
 | |
|         # by default, fx transform loses type annotation of getitem nodes.
 | |
|         for node in my_module_traced.graph.nodes:
 | |
|             if node.target == operator.getitem:
 | |
|                 assert node.type is None
 | |
| 
 | |
|         annotate_getitem_nodes(my_module_traced.graph)
 | |
| 
 | |
|         for node in my_module_traced.graph.nodes:
 | |
|             if node.target == operator.getitem:
 | |
|                 self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
 | |
| 
 | |
|         my_module = MyModule2()
 | |
|         my_module_traced = torch.fx.symbolic_trace(my_module)
 | |
| 
 | |
|         # by default, fx transform loses type annotation of getitem nodes.
 | |
|         for node in my_module_traced.graph.nodes:
 | |
|             if node.target == operator.getitem:
 | |
|                 assert node.type is None
 | |
| 
 | |
|         annotate_getitem_nodes(my_module_traced.graph)
 | |
| 
 | |
|         for node in my_module_traced.graph.nodes:
 | |
|             if node.target == operator.getitem:
 | |
|                 self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")
 | |
| 
 | |
|     def test_subgraph_uniquename(self):
 | |
|         class MyModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.linear = torch.nn.Linear(4, 4)
 | |
| 
 | |
|             def forward(self, a, b, c, d):
 | |
|                 add_1 = a + b
 | |
|                 add_2 = add_1 + c
 | |
|                 linear_1 = self.linear(add_1)
 | |
|                 add_3 = add_2 + d
 | |
|                 add_4 = add_2 + linear_1
 | |
|                 add_5 = add_3 + add_4
 | |
|                 return add_5
 | |
| 
 | |
|         a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
 | |
|         mm = MyModule()
 | |
|         traced = symbolic_trace(mm)
 | |
| 
 | |
|         def split_cb(node: torch.fx.Node):
 | |
|             if node.name == "a" or node.name == "b" or node.name == "add":
 | |
|                 return 0
 | |
|             else:
 | |
|                 return 1
 | |
| 
 | |
|         module_with_submodule = split_module(traced, mm, split_cb)
 | |
|         self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))
 | |
| 
 | |
|     def test_split_qualname_mapping(self):
 | |
|         d_hid = 4
 | |
| 
 | |
|         class ExampleCode(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
 | |
|                 self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
 | |
|                 self.lin = torch.nn.Linear(d_hid, d_hid)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 x = torch.mm(x, self.mm_param)
 | |
|                 x = torch.relu(x)
 | |
|                 x = torch.mm(x, self.mm_param)
 | |
|                 x = self.lin(x)
 | |
|                 x = torch.relu(x)
 | |
|                 x = torch.mm(x, self.mm_param2)
 | |
|                 x = self.lin(x)
 | |
|                 return x
 | |
| 
 | |
|         my_module = ExampleCode()
 | |
|         my_module_traced = symbolic_trace(my_module)
 | |
| 
 | |
|         part_idx = 0
 | |
| 
 | |
|         def split_callback(n : torch.fx.Node):
 | |
|             nonlocal part_idx
 | |
|             if (n.op, n.target) == ('call_module', 'lin'):
 | |
|                 part_idx += 1
 | |
|             return part_idx
 | |
| 
 | |
|         # split module in module with submodules
 | |
|         qualname_map : dict[str, str] = {}
 | |
|         module_with_submodules = split_module(
 | |
|             my_module_traced, my_module, split_callback, qualname_map
 | |
|         )
 | |
|         expected_qualname_map = {
 | |
|             'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
 | |
|         }
 | |
|         self.assertEqual(qualname_map, expected_qualname_map)
 | |
| 
 | |
|     def test_traceable_function_with_nonstandard_name(self):
 | |
|         def foo(x):
 | |
|             return torch.relu(x)
 | |
| 
 | |
|         traced = symbolic_trace_with_rewrite(foo)
 | |
| 
 | |
|     def test_to_folder(self):
 | |
|         class Test(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.W = torch.nn.Parameter(torch.randn(2))
 | |
|                 self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
 | |
|                 self.linear = torch.nn.Linear(2, 2)
 | |
|                 self.attr = torch.randn(2)
 | |
|                 self.attr2 = torch.nn.Buffer(torch.randn(2))
 | |
|                 self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))
 | |
| 
 | |
|         mod = symbolic_trace(Test())
 | |
|         module_name = "Foo"
 | |
|         import tempfile
 | |
|         from pathlib import Path
 | |
| 
 | |
|         with tempfile.TemporaryDirectory() as tmp_dir:
 | |
|             tmp_dir = Path(tmp_dir)
 | |
|             mod.to_folder(tmp_dir, module_name)
 | |
|             # Recipe taken from here:
 | |
|             # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
 | |
|             import importlib.util
 | |
| 
 | |
|             spec = importlib.util.spec_from_file_location(
 | |
|                 module_name, tmp_dir / "__init__.py"
 | |
|             )
 | |
|             module = importlib.util.module_from_spec(spec)
 | |
|             sys.modules[module_name] = module
 | |
|             spec.loader.exec_module(module)
 | |
|             t = torch.randn(2, 2)
 | |
|             self.assertEqual(module.Foo()(t), mod(t))
 | |
| 
 | |
|     def test_fetch(self):
 | |
|         attrs_for_lowering: dict[str, list[str]] = {
 | |
|             "torch.nn.modules.conv.Conv2d": [
 | |
|                 "weight",
 | |
|                 "bias",
 | |
|                 "kernel_size",
 | |
|                 "stride",
 | |
|                 "padding",
 | |
|                 "dilation",
 | |
|                 "groups",
 | |
|                 "padding_mode",
 | |
|             ],
 | |
|             "torch.nn.modules.batchnorm.BatchNorm2d": [
 | |
|                 "weight",
 | |
|                 "bias",
 | |
|                 "running_mean",
 | |
|                 "running_var",
 | |
|                 "eps",
 | |
|             ],
 | |
|         }
 | |
| 
 | |
|         class TestModule(torch.nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 self.conv = torch.nn.Conv2d(3, 3, 2)
 | |
|                 self.bn = torch.nn.BatchNorm2d(3)
 | |
| 
 | |
|             def forward(self, a):
 | |
|                 a = self.conv(a)
 | |
|                 a += a
 | |
|                 return self.bn(a)
 | |
| 
 | |
|         mod = TestModule()
 | |
|         traced = symbolic_trace(mod)
 | |
|         lift_lowering_attrs_to_nodes(traced)
 | |
| 
 | |
|         for node in traced.graph.nodes:
 | |
|             if node.op == "call_module":
 | |
|                 assert hasattr(node, "attrs_for_lowering")
 | |
|                 para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]
 | |
| 
 | |
|                 # node.attrs_for_lowering has an addition field of class name
 | |
|                 assert len(para_list) + 1 == len(node.attrs_for_lowering)
 | |
|                 for p_name in para_list:
 | |
|                     assert p_name in node.attrs_for_lowering
 | |
| 
 | |
|     def test_merge_matmuls(self):
 | |
|         """
 | |
|         A collection of test cases for torch.fx.experimental.merge_matmul,
 | |
|         a graph transformation that merges matrix multiplication operations.
 | |
|         """
 | |
|         # Utility function for counting matmuls for test assertions.
 | |
|         def _count_matmuls(mod):
 | |
|             gm = torch.fx.symbolic_trace(mod)
 | |
| 
 | |
|             num_matmuls = 0
 | |
|             for node in gm.graph.nodes:
 | |
|                 if node.target == torch.matmul:
 | |
|                     num_matmuls += 1
 | |
| 
 | |
|             return num_matmuls
 | |
| 
 | |
|         # Simple test case in which there are two matmuls of the same size to merge.
 | |
|         class SimpleMergeMatmulModule(torch.nn.Module):
 | |
|             def __init__(self, rhs):
 | |
|                 super().__init__()
 | |
|                 self.rhs = rhs
 | |
| 
 | |
|             def forward(self, x, y):
 | |
|                 a = torch.matmul(x, self.rhs)
 | |
|                 b = torch.matmul(y, self.rhs)
 | |
|                 return a + b
 | |
| 
 | |
|         # Initialize inputs.
 | |
|         a = torch.randn(3, 3)
 | |
|         b = torch.randn(3, 3)
 | |
| 
 | |
|         # Initialize RHS for matmuls.
 | |
|         rhs = torch.randn(3, 4)
 | |
| 
 | |
|         # Construct SimpleMergeMatmulModule and call merge_matmul on it.
 | |
|         module = SimpleMergeMatmulModule(rhs)
 | |
|         opt_module = merge_matmul.merge_matmul(module)
 | |
| 
 | |
|         # Numerical correctness check.
 | |
|         before = module(a, b)
 | |
|         after = opt_module(a, b)
 | |
|         before.allclose(after)
 | |
| 
 | |
|         # Basic graph structure check; original module should have 2 matmuls
 | |
|         # and optimized module should have 1.
 | |
|         self.assertEqual(_count_matmuls(module), 2)
 | |
|         self.assertEqual(_count_matmuls(opt_module), 1)
 | |
| 
 | |
|         # Test case in which there are multiple matmuls of different sizes to merge.
 | |
|         class FiveMergeMatmulModule(torch.nn.Module):
 | |
|             def __init__(self, rhs):
 | |
|                 super().__init__()
 | |
|                 self.rhs = rhs
 | |
| 
 | |
|             def forward(self, a, b, c, d, e):
 | |
|                 s = torch.tensor([])
 | |
|                 matmuls = []
 | |
| 
 | |
|                 # For some reason using a list comprehension or for-loop for this
 | |
|                 # doesn't work.
 | |
|                 matmuls.append(torch.matmul(a, self.rhs))
 | |
|                 matmuls.append(torch.matmul(b, self.rhs))
 | |
|                 matmuls.append(torch.matmul(c, self.rhs))
 | |
|                 matmuls.append(torch.matmul(d, self.rhs))
 | |
|                 matmuls.append(torch.matmul(e, self.rhs))
 | |
| 
 | |
|                 for m in matmuls:
 | |
|                     s += torch.sum(m)
 | |
| 
 | |
|                 return s
 | |
| 
 | |
|         # Initialize inputs.
 | |
|         inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]
 | |
| 
 | |
|         # Initialize RHS.
 | |
|         rhs = torch.randn(5, 4)
 | |
| 
 | |
|         # Construct FiveMergeMatmulModule and call merge_matmul on it.
 | |
|         module = FiveMergeMatmulModule(rhs)
 | |
|         opt_module = merge_matmul.merge_matmul(module)
 | |
| 
 | |
|         # Numerical correctness check.
 | |
|         before = module(*inputs)
 | |
|         after = opt_module(*inputs)
 | |
|         before.allclose(after)
 | |
| 
 | |
|         # Basic graph structure check; original module should have len(inputs) matmuls
 | |
|         # and optimized module should have 1.
 | |
|         self.assertEqual(_count_matmuls(module), len(inputs))
 | |
|         self.assertEqual(_count_matmuls(opt_module), 1)
 | |
| 
 | |
|         # Simple test case in which two matmuls cannot be merged due to a data dependency between
 | |
|         # the LHS operands.
 | |
|         class UnmergeableMatmulModule(torch.nn.Module):
 | |
|             def __init__(self, rhs):
 | |
|                 super().__init__()
 | |
|                 self.rhs = rhs
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 a = torch.matmul(x, self.rhs)
 | |
|                 a_abs = torch.abs(a)
 | |
|                 b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
 | |
|                 return b
 | |
| 
 | |
|         # Initialize inputs.
 | |
|         a = torch.randn(3, 3)
 | |
| 
 | |
|         # Initialize RHS for matmuls.
 | |
|         rhs = torch.randn(3, 4)
 | |
| 
 | |
|         # Construct UnmergeableMatmulModule and call merge_matmul on it.
 | |
|         module = UnmergeableMatmulModule(rhs)
 | |
|         opt_module = merge_matmul.merge_matmul(module)
 | |
| 
 | |
|         # Numerical correctness check.
 | |
|         before = module(a)
 | |
|         after = opt_module(a)
 | |
|         before.allclose(after)
 | |
| 
 | |
|         # Basic graph structure check; the number of matrix multiplcations should not have changed.
 | |
|         self.assertEqual(_count_matmuls(module), 2)
 | |
|         self.assertEqual(_count_matmuls(opt_module), 2)
 | |
| 
 | |
|     def test_type_matches(self):
 | |
|         should_be_equal = [
 | |
|             (int, int),
 | |
|             (numbers.Number, int),
 | |
|             (numbers.Number, float),
 | |
|             (int, type(torch.float)),
 | |
|             (Union[int, float], int),
 | |
|             (Union[int, float], float),
 | |
|             (list[int], int),
 | |
|             (list[int], create_type_hint([int, int])),
 | |
|             (list[int], create_type_hint((int, int))),
 | |
|             (list[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
 | |
|             (
 | |
|                 list[torch.Tensor],
 | |
|                 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
 | |
|             ),
 | |
|             (torch.Tensor, torch.nn.Parameter),
 | |
|             (list[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
 | |
|             (list[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
 | |
|             (list[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
 | |
|             (
 | |
|                 list[torch.Tensor],
 | |
|                 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
 | |
|             ),
 | |
|             (torch.Tensor, torch.nn.Parameter),
 | |
|             (list[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
 | |
|             (list[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
 | |
|             (Optional[list[torch.Tensor]], list[torch.Tensor]),
 | |
|             (Optional[list[int]], list[int]),
 | |
|         ] + [
 | |
|             # pre-PEP585 signatures
 | |
|             (typing.List[int], int),  # noqa: UP006
 | |
|             (typing.List[int], create_type_hint([int, int])),  # noqa: UP006
 | |
|             (typing.List[int], create_type_hint((int, int))),  # noqa: UP006
 | |
|             (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),  # noqa: UP006
 | |
|             (
 | |
|                 typing.List[torch.Tensor],  # noqa: UP006
 | |
|                 create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
 | |
|             ),
 | |
|             (typing.List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),  # noqa: UP006
 | |
|             (typing.List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),  # noqa: UP006
 | |
|             (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),  # noqa: UP006
 | |
|             (
 | |
|                 typing.List[torch.Tensor],  # noqa: UP006
 | |
|                 create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
 | |
|             ),
 | |
|             (typing.List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),  # noqa: UP006
 | |
|             (typing.List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),  # noqa: UP006
 | |
|             (Optional[typing.List[torch.Tensor]], typing.List[torch.Tensor]),  # noqa: UP006
 | |
|             (Optional[typing.List[int]], typing.List[int]),  # noqa: UP006
 | |
|         ]
 | |
| 
 | |
|         for sig_type, arg_type in should_be_equal:
 | |
|             self.assertTrue(type_matches(sig_type, arg_type))
 | |
| 
 | |
|         should_fail = [
 | |
|             (int, float),
 | |
|             (Union[int, float], str),
 | |
|             (list[torch.Tensor], typing.List[int]),  # noqa: UP006
 | |
|         ] + [
 | |
|             # pre-PEP585 signatures
 | |
|             (list[torch.Tensor], list[int]),
 | |
|         ]
 | |
| 
 | |
|         for sig_type, arg_type in should_fail:
 | |
|             self.assertFalse(type_matches(sig_type, arg_type))
 | |
| 
 | |
|     @skipIfNoMkldnn
 | |
|     def test_optimize_for_inference_cpu(self):
 | |
|         import torch.nn as nn
 | |
| 
 | |
|         class Foo(nn.Module):
 | |
|             def __init__(self) -> None:
 | |
|                 super().__init__()
 | |
|                 layers = []
 | |
|                 layers2 = []
 | |
|                 for _ in range(10):
 | |
|                     layers.append(nn.Conv2d(3, 3, 1))
 | |
|                     layers.append(nn.BatchNorm2d(3))
 | |
|                     layers.append(nn.ReLU())
 | |
| 
 | |
|                     layers2.append(nn.Conv2d(3, 3, 1))
 | |
|                     layers2.append(nn.BatchNorm2d(3))
 | |
|                     layers2.append(nn.ReLU())
 | |
|                 self.model = nn.Sequential(*layers)
 | |
|                 self.model2 = nn.Sequential(*layers2)
 | |
| 
 | |
|             def forward(self, x):
 | |
|                 return self.model(x) + self.model2(x)
 | |
| 
 | |
|         N, C, H, W, = (
 | |
|             1,
 | |
|             3,
 | |
|             224,
 | |
|             224,
 | |
|         )
 | |
|         inp = torch.randn(N, C, H, W)
 | |
|         with torch.no_grad():
 | |
|             model = Foo().eval()
 | |
|             optimized_model = optimization.optimize_for_inference(model)
 | |
|             torch.testing.assert_close(model(inp), optimized_model(inp))
 | |
| 
 | |
|             optimized_model2 = optimization.optimize_for_inference(
 | |
|                 model, pass_config={"remove_dropout": False}
 | |
|             )
 | |
|             torch.testing.assert_close(model(inp), optimized_model2(inp))
 | |
| 
 | |
|     @skipIfNoTorchVision
 | |
|     @skipIfNoMkldnn
 | |
|     def test_optimize_for_inference_cpu_torchvision(self):
 | |
|         models = [
 | |
|             torchvision.models.resnet18,
 | |
|             torchvision.models.resnet50,
 | |
|             torchvision.models.densenet121,
 | |
|             torchvision.models.shufflenet_v2_x1_0,
 | |
|             torchvision.models.vgg16,
 | |
|             torchvision.models.mobilenet_v2,
 | |
|             torchvision.models.mnasnet1_0,
 | |
|             torchvision.models.resnext50_32x4d,
 | |
|         ]
 | |
|         with torch.no_grad():
 | |
|             for model_type in models:
 | |
|                 model = model_type()
 | |
|                 C, H, W, = (
 | |
|                     3,
 | |
|                     224,
 | |
|                     224,
 | |
|                 )
 | |
|                 inp = torch.randn(3, C, H, W)
 | |
|                 model(inp)
 | |
|                 model.eval()
 | |
|                 inp = torch.randn(1, C, H, W)
 | |
|                 heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
 | |
|                 optimized_model = optimization.optimize_for_inference(model)
 | |
| 
 | |
|                 orig_out = model(inp)
 | |
|                 new_out = optimized_model(inp)
 | |
|                 torch.testing.assert_close(orig_out, new_out)
 | |
| 
 | |
| 
 | |
class TestNormalizeOperators(JitTestCase):
    """Tests for FX argument normalization (``normalize_function`` and ``Node.normalized_arguments``)."""

    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
        """For every OpInfo sample, check that normalizing the call's arguments preserves the result.

        Two paths are exercised per sample:
          1. ``normalize_function`` called directly on the op.
          2. ``Node.normalized_arguments`` applied to the call_function nodes of a
             module generated from source text, symbolically traced, and recompiled.
        """
        # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"}
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            # Positional args are the sample's primary input followed by its extra args.
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                # Infer a TorchScript type for a non-tensor value and convert it
                # to the equivalent Python type hint.
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                t = _torchscript_type_to_python_type(inferred_arg_type.type())
                return t

            # Collect a type for every positional argument (tensors use their own class).
            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            # Same for keyword arguments.
            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                # NOTE(review): the message mentions an "op_skip list", but no such
                # list is visible in this test — confirm where skips are registered.
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the op_skip list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
            # Tensor arguments become parameters of the generated forward()
            # (param_names / param_values); all other values are inlined into
            # the generated source text via repr().
            param_names = []
            param_values = []
            fx_args = []

            idx = 0

            def process_arg(arg, name):
                # Tensor leaf: register it as a forward() parameter and refer to it by name.
                if isinstance(arg, torch.Tensor):
                    param_names.append(name)
                    param_values.append(arg)
                    return name
                else:
                    return f"{repr(arg)}"

            def process_arg_with_idx(arg):
                # Positional leaves are named arg_0, arg_1, ... in traversal order.
                nonlocal idx
                res = process_arg(arg, f"arg_{idx}")
                idx = idx + 1
                return res

            def str_arg(arg):
                # Render a (possibly nested) tuple/list of processed leaves as source text.
                # Each tuple element keeps a trailing comma so 1-tuples stay tuples.
                if isinstance(arg, tuple):
                    args = [f"{str_arg(v)}, " for v in arg]
                    return f"({' '.join(args)})"
                elif isinstance(arg, list):
                    args = [f"{str_arg(v)}" for v in arg]
                    return f"[{', '.join(args)}]"
                else:
                    return arg

            for v in arg_values:
                arg = pytree.tree_map(process_arg_with_idx, v)
                fx_args.append(str_arg(arg))

            for k, v in kwarg_values.items():
                arg = pytree.tree_map(functools.partial(process_arg, name=k), v)
                fx_args.append(f"{k} = {str_arg(arg)}")

            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
            """

            # repr(math.inf) is "inf", so the generated source needs that name in scope.
            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            # Replace each call's args/kwargs with their normalized form, then re-run.
            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
        """Normalizing a quantized embedding-bag op to kwargs-only must yield its full keyword set and no positional args."""
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertTrue(norm_args_and_kwargs is not None)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, ())

    def test_normalize_args_op_overload(self):
        """Kwargs-only normalization must work for both ``resize_as_.default`` and the bare ``resize_as_`` op.

        The original argument objects must be passed through unchanged.
        """
        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
            inp1 = torch.rand([1])
            inp2 = torch.rand([4])
            args, kwargs = normalize_function(target, (inp1,), {"the_template": inp2}, normalize_to_only_use_kwargs=True)
            self.assertIs(kwargs["input"], inp1)
            self.assertIs(kwargs["the_template"], inp2)
 | |
| 
 | |
| 
 | |
if TEST_Z3:
    import z3

    import torch._dynamo.config

    from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
    from torch.utils._sympy.functions import FloorDiv, Mod, BitwiseFn_bitwise_and

    class TestTranslationValidation(TestCase):
        """Tests for SymPy-to-Z3 translation and the TranslationValidator (only defined when Z3 is available)."""

        def _prepare_for_translation_validation(self):
            """Create a validator with three integer SymPy symbols registered.

            Returns ``((s0, s1, s2), (z0, z1, z2), validator)``: the SymPy
            symbols, their Z3 counterparts, and the validator that owns them.
            """
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols. A plain loop performs the registration; nothing needs
            # to be collected from add_var.
            for s in (s0, s1, s2):
                validator.add_var(s, int)
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):
            """Each SymPy expression must translate to a structurally equal Z3 expression."""
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                (
                    s0,
                    z0,
                ),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Logical operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Bitwise operations.
                (BitwiseFn_bitwise_and(s0, s1), z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64))),
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            """Validation passes when the target constraints imply the source constraints."""
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for target is a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

            validator.validate()

        def test_sat_bitwise(self):
            """Validation passes with only source constraints that use a bitwise AND."""
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64)) == 5)
            validator.add_source_expr(z0 == 0b110101)

            validator.validate()

        def test_unsat(self):
            """Validation fails when the target admits solutions the source rejects."""
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # Solutions for target is NOT a subset of the solutions for the source.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its counterpart.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            """z3str must render Z3 expressions in the expected prefix notation."""
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            test_cases = [
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*")
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)
 | |
| 
 | |
| 
 | |
# Generate the device-specific variants of TestNormalizeOperators in this
# module's namespace (presumably one class per available device type).
instantiate_device_type_tests(TestNormalizeOperators, scope=globals())

if __name__ == "__main__":
    # Entry point when the file is executed directly rather than collected.
    run_tests()
 |