mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: replace export_for_training with export Test Plan: CI Rollback Plan: Differential Revision: D81935792 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162396 Approved by: https://github.com/angelayi, https://github.com/jerryzh168
274 lines
11 KiB
Python
274 lines
11 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
import os
|
|
import sys
|
|
from typing import Callable
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.export import export
|
|
from torch.fx import symbolic_trace
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
|
|
|
|
# Append the parent of this file's directory to ``sys.path`` so that modules
# living there become importable (presumably PyTorch's top-level ``test/``
# directory, going by the variable name — confirm against the repo layout).
pytorch_test_dir = os.path.dirname(
    os.path.dirname(os.path.realpath(__file__))
)
sys.path.append(pytorch_test_dir)
|
|
import unittest
|
|
|
|
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
|
|
from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
|
|
SubgraphMatcherWithNameNodeMap,
|
|
)
|
|
from torch.testing._internal.common_utils import IS_WINDOWS
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
|
|
class WrapperModule(torch.nn.Module):
    """Expose a plain callable as an ``nn.Module``.

    This lets free functions be handed to APIs that expect a module, such as
    ``torch.export.export``. Calling the module forwards all positional and
    keyword arguments, unchanged, to the wrapped callable and returns its
    result.
    """

    def __init__(self, fn: Callable):
        super().__init__()
        # Kept as an ordinary attribute; ``forward`` simply delegates to it.
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
|
|
|
|
|
|
class TestMatcher(JitTestCase):
    """Tests for ``SubgraphMatcher`` and ``SubgraphMatcherWithNameNodeMap``."""

    def test_subgraph_matcher_with_attributes(self):
        """A pattern matches even when its parameters differ in shape.

        The target reads 3x3 parameters and the pattern reads 5x5 ones, yet
        exactly one match is expected.
        """

        class LargeModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._weight = torch.nn.Parameter(torch.ones(3, 3))
                self._bias = torch.nn.Parameter(torch.ones(3, 3))

            def forward(self, x):
                return torch.ops.aten.addmm.default(self._bias, x, self._weight)

        # Large Model graph:
        # opcode         name           target              args                 kwargs
        # -------------  -------------  ------------------  -------------------  --------
        # placeholder    x              x                   ()                   {}
        # get_attr       _bias          _bias               ()                   {}
        # get_attr       _weight        _weight             ()                   {}
        # call_function  addmm_default  aten.addmm.default  (_bias, x, _weight)  {}
        # output         output         output              (addmm_default,)     {}
        large_model_graph = symbolic_trace(LargeModel()).graph

        class PatternModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._weight_1 = torch.nn.Parameter(torch.ones(5, 5))
                self._bias_1 = torch.nn.Parameter(torch.ones(5, 5))

            def forward(self, x):
                return torch.ops.aten.addmm.default(self._bias_1, x, self._weight_1)

        pattern_graph = symbolic_trace(PatternModel()).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(large_model_graph)
        self.assertEqual(len(match_result), 1)

    def test_subgraph_matcher_with_list(self):
        """A placeholder inside a list argument matches a literal in the target.

        The pattern passes ``[z, ...]`` where the original passes ``[5, ...]``.
        """

        def original(x, y):
            return torch.ops.aten.view(x, [5, y.shape[0]])

        original_graph = symbolic_trace(original).graph

        def pattern(x, y, z):
            return torch.ops.aten.view(x, [z, y.shape[0]])

        pattern_graph = symbolic_trace(pattern).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 1)

    def test_subgraph_matcher_with_list_bad(self):
        """List arguments whose lengths differ must not match (0 matches)."""

        def original(x, y):
            return torch.ops.aten._reshape_alias_copy.default(
                x, [1, y.shape[0]], [y.shape[1], y.shape[1]]
            )

        original_graph = symbolic_trace(original).graph

        def pattern(x, y, b):
            return torch.ops.aten._reshape_alias_copy.default(
                x, [b, y.shape[0], y.shape[1]], [y.shape[1]]
            )

        pattern_graph = symbolic_trace(pattern).graph

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 0)

    def test_subgraph_matcher_ignore_literals(self):
        """Differing literals (``+ 1`` vs. ``+ 2``) block a match unless
        ``ignore_literals=True`` is passed."""

        def original(x):
            return x + 1

        original_graph = make_fx(original)(torch.ones(3, 3)).graph
        original_graph.eliminate_dead_code()

        def pattern(x):
            return x + 2

        pattern_graph = make_fx(pattern)(torch.ones(4, 4)).graph
        pattern_graph.eliminate_dead_code()

        subgraph_matcher = SubgraphMatcher(pattern_graph)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 0)

        subgraph_matcher = SubgraphMatcher(pattern_graph, ignore_literals=True)
        match_result = subgraph_matcher.match(original_graph)
        self.assertEqual(len(match_result), 1)

    # NOTE: "variatic" is a typo for "variadic"; the name is kept so the test
    # ID stays stable for CI tracking.
    def test_variatic_arg_matching(self):
        """A pattern passing every ``max_pool2d_with_indices`` argument matches
        graphs traced from ``nn.MaxPool2d`` configured with different subsets
        of ``stride``/``padding``."""
        inputs = (torch.randn(20, 16, 50, 32),)

        def maxpool(x, kernel_size, stride, padding, dilation):
            return torch.ops.aten.max_pool2d_with_indices.default(
                x, kernel_size, stride, padding, dilation
            )

        maxpool_graph = symbolic_trace(maxpool).graph

        maxpool_matcher = SubgraphMatcher(maxpool_graph)
        match_result = maxpool_matcher.match(maxpool_graph)
        self.assertEqual(len(match_result), 1)

        # Graph only contains "stride" argument
        maxpool_s = torch.nn.MaxPool2d(kernel_size=2, stride=1).eval()
        maxpool_s_graph = make_fx(maxpool_s)(*inputs).graph
        match_s_result = maxpool_matcher.match(maxpool_s_graph)
        self.assertEqual(len(match_s_result), 1)

        # Graph only contains "padding" argument
        maxpool_p = torch.nn.MaxPool2d(kernel_size=2, padding=1)
        maxpool_p_graph = make_fx(maxpool_p)(*inputs).graph
        match_p_result = maxpool_matcher.match(maxpool_p_graph)
        self.assertEqual(len(match_p_result), 1)

        # Graph only contains "stride, padding" argument
        maxpool_sp = torch.nn.MaxPool2d(kernel_size=2, stride=1, padding=1)
        maxpool_sp_graph = make_fx(maxpool_sp)(*inputs).graph
        match_sp_result = maxpool_matcher.match(maxpool_sp_graph)
        self.assertEqual(len(match_sp_result), 1)

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_split_to_graph_and_name_node_map(self):
        """Testing the internal helper function for splitting the pattern graph"""
        from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
            _split_to_graph_and_name_node_map,
        )

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = export(
            WrapperModule(pattern), example_inputs, strict=True
        ).module()
        before_split_res = pattern_gm(*example_inputs)
        pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm)
        after_split_res = pattern_gm(*example_inputs)
        # Splitting drops the trailing name->node dict from the outputs; the
        # remaining (tensor) outputs must be unchanged.
        self.assertEqual(before_split_res[0], after_split_res[0])
        self.assertEqual(before_split_res[1], after_split_res[1])

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_matcher_with_name_node_map_function(self):
        """Testing SubgraphMatcherWithNameNodeMap with function pattern"""

        def target_graph(x, weight):
            x = x * 2
            weight = weight * 3
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu2 = relu * 2
            return relu + relu2

        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu_mul_by_two = relu * 2
            return relu, relu_mul_by_two, {"conv": conv, "relu": relu}

        example_inputs = (
            torch.randn(1, 3, 3, 3) * 10,
            torch.randn(3, 3, 3, 3),
        )
        pattern_gm = export(
            WrapperModule(pattern), example_inputs, strict=True
        ).module()
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        target_gm = export(
            WrapperModule(target_graph), example_inputs, strict=True
        ).module()
        internal_matches = matcher.match(target_gm.graph)
        # The conv -> relu -> (relu * 2) chain occurs once in the target.
        # Asserting the count keeps the checks below from passing vacuously
        # when the matcher finds nothing.
        self.assertEqual(len(internal_matches), 1)
        name_node_map = internal_matches[0].name_node_map
        self.assertIn("conv", name_node_map)
        self.assertIn("relu", name_node_map)
        name_node_map["conv"].meta["custom_annotation"] = "annotation"
        # check if we correctly annotated the target graph module: the mapped
        # node must be a node of the target graph, and the only one annotated.
        annotated = [
            n
            for n in target_gm.graph.nodes
            if n.meta.get("custom_annotation") == "annotation"
        ]
        self.assertEqual(len(annotated), 1)
        self.assertIs(annotated[0], name_node_map["conv"])

    @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
    def test_matcher_with_name_node_map_module(self):
        """Testing SubgraphMatcherWithNameNodeMap with module pattern"""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x)

        class Pattern(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                linear = self.linear(x)
                # Note: we can't put "weight": self.linear.weight in dictionary since
                # nn.Parameter is not an allowed output type in dynamo
                return linear, {"linear": linear, "x": x}

        example_inputs = (torch.randn(3, 5),)
        pattern_gm = export(Pattern(), example_inputs, strict=True).module()
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        target_gm = export(M(), example_inputs, strict=True).module()
        internal_matches = matcher.match(target_gm.graph)
        # The target contains a single linear call, so exactly one match is
        # expected; without this the checks below could pass vacuously.
        self.assertEqual(len(internal_matches), 1)
        name_node_map = internal_matches[0].name_node_map
        self.assertIn("linear", name_node_map)
        self.assertIn("x", name_node_map)
        name_node_map["linear"].meta["custom_annotation"] = "annotation"
        # check if we correctly annotated the target graph module: the mapped
        # node must be a node of the target graph, and the only one annotated.
        annotated = [
            n
            for n in target_gm.graph.nodes
            if n.meta.get("custom_annotation") == "annotation"
        ]
        self.assertEqual(len(annotated), 1)
        self.assertIs(annotated[0], name_node_map["linear"])
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this file directly is unsupported: it is expected to be picked
    # up through discover_tests.py instead, so fail loudly with a pointer.
    raise RuntimeError(
        "This test is not currently used and should be enabled in discover_tests.py if required."
    )
|