mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: as title Test Plan: ci Rollback Plan: Reviewed By: yushangdi Differential Revision: D78296410 Pull Request resolved: https://github.com/pytorch/pytorch/pull/158322 Approved by: https://github.com/yushangdi
273 lines
9.8 KiB
Python
273 lines
9.8 KiB
Python
# Owner(s): ["module: fx"]
|
|
|
|
import torch
|
|
from torch._inductor.compile_fx import aot_export_module
|
|
from torch.export import default_decompositions
|
|
from torch.fx.traceback import get_graph_provenance_json, NodeSource, NodeSourceAction
|
|
from torch.testing._internal.common_utils import TestCase
|
|
|
|
|
|
CREATE_STR = NodeSourceAction.CREATE.name.lower()
|
|
|
|
|
|
class TestFXNodeSource(TestCase):
    """Tests for ``NodeSource`` and FX graph provenance (``from_node``) tracking."""

    def test_node_source(self):
        """Check NodeSource's readable and dict forms, equality, and hashing."""
        # A NodeSource without a backing node: empty name and graph_id of -1.
        node_source = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        self.assertExpectedInline(
            node_source.print_readable().strip(),
            """(name=, pass_name=test_pass, action=create, graph_id=-1)""",
        )
        dummy_source_dict = {
            "name": "",
            "target": "",
            "pass_name": "test_pass",
            "action": CREATE_STR,
            "graph_id": -1,
            "from_node": [],
        }
        self.assertEqual(
            node_source.to_dict(),
            dummy_source_dict,
        )

        # Dummy node whose provenance points back at the node-less source above.
        node = torch.fx.Node(
            graph=torch.fx.Graph(),
            name="add",
            op="call_function",
            target=torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
            args=(torch.tensor(3), torch.tensor(4)),
            kwargs={},
        )
        node.meta["from_node"] = [node_source]

        # The expected graph_id is the id() of the node's owning graph.
        graph_id = id(node.graph)
        node_source = NodeSource(
            node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        # NOTE(review): the nested from_node entry is expected to be indented in
        # print_readable()'s output; confirm the 4-space indent of the second
        # line against NodeSource.print_readable.
        self.assertExpectedInline(
            node_source.print_readable().strip(),
            f"""\
(name=add, pass_name=test_pass, action=create, graph_id={graph_id})
    (name=, pass_name=test_pass, action=create, graph_id=-1)""",
        )
        self.assertEqual(
            node_source.to_dict(),
            {
                "name": "add",
                "target": "aten.add.Tensor",
                "pass_name": "test_pass",
                "action": CREATE_STR,
                "graph_id": graph_id,
                "from_node": [dummy_source_dict],
            },
        )

        # Two independently constructed but identical node sources are equal.
        node_source1 = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        node_source2 = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        self.assertEqual(node_source1, node_source2)

        # Hash function - equal objects must have the same hash.
        self.assertEqual(hash(node_source1), hash(node_source2))

        # Node sources that differ only in pass_name are not equal.
        node_source3 = NodeSource(
            node=None, pass_name="test_pass_1", action=NodeSourceAction.CREATE
        )
        node_source4 = NodeSource(
            node=None, pass_name="test_pass_2", action=NodeSourceAction.CREATE
        )
        self.assertNotEqual(node_source3, node_source4)

        # Hash function - these unequal objects are expected to hash differently
        # (Python does not guarantee this in general; it relies on no collision
        # for these specific inputs).
        self.assertNotEqual(hash(node_source3), hash(node_source4))

        # Equal NodeSource objects collapse to a single entry in sets and dicts.
        node_set = {node_source1, node_source2}
        self.assertEqual(len(node_set), 1)  # Should only contain one unique element

        node_dict = {node_source1: "value1", node_source2: "value2"}
        self.assertEqual(len(node_dict), 1)  # Should only contain one key
        self.assertEqual(node_dict[node_source1], "value2")  # Last value should win

        # Same checks with NodeSource objects that wrap a real node.
        node_source_with_node = NodeSource(
            node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
        )
        node_source_with_node_copy = NodeSource(
            node=node, pass_name="test_pass", action=NodeSourceAction.CREATE
        )

        # These should be equal and have the same hash.
        self.assertEqual(node_source_with_node, node_source_with_node_copy)
        self.assertEqual(hash(node_source_with_node), hash(node_source_with_node_copy))

        # Node sources that differ only in action.
        node_source_replace = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.REPLACE
        )
        node_source_create = NodeSource(
            node=None, pass_name="test_pass", action=NodeSourceAction.CREATE
        )

        # These should be different and have different hashes.
        self.assertNotEqual(node_source_replace, node_source_create)
        self.assertNotEqual(hash(node_source_replace), hash(node_source_create))

    def test_graph_provenance(self):
        """Check from_node provenance through export, decomposition, and AOT export."""

        def check_node_source(node_source_dict, name, pass_name, action):
            """Assert the name, pass_name and action fields of a serialized NodeSource."""
            self.assertEqual(node_source_dict["name"], name)
            self.assertEqual(node_source_dict["pass_name"], pass_name)
            self.assertEqual(node_source_dict["action"], action)

        def get_first_node_source_and_check(node_source_dict):
            """
            Get the first node source from the from_node list.

            Also asserts that the from_node list has exactly one entry.
            """
            self.assertEqual(len(node_source_dict["from_node"]), 1)
            return node_source_dict["from_node"][0]

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = torch.nn.Linear(10, 16)
                self.relu = torch.nn.ReLU()
                self.fc2 = torch.nn.Linear(16, 1)
                self.sigmoid = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.fc1(x)
                x = self.relu(x)
                x = self.fc2(x)
                x = self.sigmoid(x)
                return (x,)

        model = Model()
        example_inputs = (torch.randn(8, 10),)
        ep = torch.export.export(model, example_inputs, strict=True)

        decomposed_ep = ep.run_decompositions(default_decompositions())
        # Every node other than placeholders/outputs must carry from_node
        # provenance. self.assertIn (unlike a bare assert) is not stripped
        # under `python -O` and reports the failing meta on error.
        for node in decomposed_ep.graph.nodes:
            if node.op not in {"placeholder", "output"}:
                self.assertIn("from_node", node.meta)

        # Nodes decomposed from the same ancestor node should have the same
        # from_node info; every other pair of nodes should differ.
        node_name_to_from_node = {
            node.name: node.meta["from_node"]
            for node in decomposed_ep.graph.nodes
            if node.op not in {"placeholder", "output"}
        }
        same_ancestor_nodes = {
            "permute": "addmm",
            "addmm": "permute",
            "permute_1": "addmm_1",
            "addmm_1": "permute_1",
        }

        for node_name_1, from_node_1 in node_name_to_from_node.items():
            # A node matches itself, plus its sibling decomposition (if any).
            expected_same = {node_name_1, same_ancestor_nodes.get(node_name_1)}
            for node_name_2, from_node_2 in node_name_to_from_node.items():
                if node_name_2 in expected_same:
                    self.assertEqual(from_node_1, from_node_2)
                else:
                    self.assertNotEqual(from_node_1, from_node_2)

        gm = ep.module()
        provenance = get_graph_provenance_json(gm.graph)
        self.assertEqual(
            set(provenance.keys()), {"relu", "linear", "sigmoid", "linear_1"}
        )

        # Check node "linear" is created from node "x" in PropagateUnbackedSymInts
        key_provenance = provenance["linear"][0]["from_node"]
        self.assertEqual(len(key_provenance), 1)
        key_provenance = key_provenance[0]
        check_node_source(
            key_provenance,
            "x",
            "Interpreter_PropagateUnbackedSymInts",
            CREATE_STR,
        )

        # Check node "x" is then created from another node "x" in FlattenInputOutputSignature
        key_provenance = get_first_node_source_and_check(key_provenance)
        check_node_source(
            key_provenance,
            "x",
            "Interpreter_FlattenInputOutputSignature",
            CREATE_STR,
        )

        # The returned graph signature is not needed by this test.
        gm, _ = aot_export_module(
            gm,
            example_inputs,
            trace_joint=False,
        )

        provenance = get_graph_provenance_json(gm.graph)

        self.assertEqual(
            set(provenance.keys()), {"t", "addmm", "relu", "t_1", "addmm_1", "sigmoid"}
        )
        for key in ["t", "addmm"]:
            # The provenance chain walked below is:
            #   key -> linear -> <intermediate entry, not checked> -> x -> x
            #
            # a -> b means a is created from b
            key_provenance = provenance[key]
            self.assertEqual(len(key_provenance), 1)
            key_provenance = key_provenance[0]

            # Check that nodes "t" and "addmm" are created from node "linear" in PropagateUnbackedSymInts
            check_node_source(
                key_provenance,
                "linear",
                "Interpreter_PropagateUnbackedSymInts",
                CREATE_STR,
            )

            # Check that "linear" ultimately comes from node "x" in
            # PropagateUnbackedSymInts. This descends two from_node levels and
            # only checks the second, so the intermediate entry is not verified.
            key_provenance = get_first_node_source_and_check(key_provenance)[
                "from_node"
            ][0]
            check_node_source(
                key_provenance,
                "x",
                "Interpreter_PropagateUnbackedSymInts",
                CREATE_STR,
            )

            # Check node "x" is then created from another node "x" in FlattenInputOutputSignature
            key_provenance = get_first_node_source_and_check(key_provenance)
            check_node_source(
                key_provenance,
                "x",
                "Interpreter_FlattenInputOutputSignature",
                CREATE_STR,
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Running this module directly is unsupported: per the message below, the
    # test must first be enabled in discover_tests.py.
    raise RuntimeError(
        "This test is not currently used and should be enabled in discover_tests.py if required."
    )
|