pytorch/test/fx/test_fx_split.py
Yuanyuan Chen 8de85896e0 Enable ruff rule E721 (#165162)
`E721` flags object type comparisons made with `==` and other equality operators; for type checks, identity comparison with `is` is the recommended form.
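
For illustration, a minimal sketch of what the rule flags (our example, not part of the commit):

    type(a) == type(b)   # flagged by E721
    type(a) is type(b)   # preferred; use isinstance() for subclass-aware checks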

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165162
Approved by: https://github.com/Skylion007
2025-10-13 01:48:55 +00:00


# Owner(s): ["module: fx"]
import dataclasses
from collections import defaultdict

import torch
import torch.fx.passes.operator_support as op_support
import torch.fx.passes.splitter_base as splitter_base
from torch.fx.passes.split_utils import split_by_tags
from torch.testing._internal.common_utils import TestCase
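

# Scripted dataclass used as a graph entry point: constructing it inside
# forward() yields a call_function node with no input dependencies
# (see test_dataclass_as_graph_entry below).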
@torch.jit.script
@dataclasses.dataclass
class DummyDataClass:
a: int
b: int
c: int


@torch.fx.wrap
def wrapped_add(_dataclass, y):
return _dataclass.c + y


class TestFXSplit(TestCase):
def test_split_preserve_node_meta(self):
class TestModule(torch.nn.Module):
def forward(self, x, y):
x = x + x
y = y * y
return x - y
gm = torch.fx.symbolic_trace(TestModule())
for node in gm.graph.nodes:
node.meta["name"] = node.name
if node.name == "add":
node.tag = "a"
elif node.name == "mul":
node.tag = "b"
elif node.name == "sub":
node.tag = "c"
split_gm = split_by_tags(gm, ["a", "b", "c"])
for m in split_gm.children():
for n in m.graph.nodes:
if n.op != "output":
self.assertIn("name", n.meta)
self.assertEqual(n.meta["name"], n.name)
# Validate that metadata is copied correctly for graph placeholder nodes
for node in split_gm.graph.nodes:
if node.op == "placeholder":
self.assertIn("name", node.meta)
self.assertEqual(node.meta["name"], node.name)

    def test_dataclass_as_graph_entry(self):
"""
Test that splitting works when the graph entry is a dataclass instance
and a wrapped function is called with it, resulting in a call_function
        node with no input dependencies. This exercises the edge case fixed in
        D81232435, where call_function nodes with no dependencies must be handled
        properly in the starter_nodes() method.
Graph visualization:
        y (input)      DummyDataClass(2, 3, 4)  (no input deps; a call_function node)
\ /
\ /
wrapped_add
|
z (output)
""" # noqa: W605
class TestModuleWithFunctionEntry(torch.nn.Module):
def forward(self, y):
# This creates a call_function node with no input dependencies
dummy_data_class = DummyDataClass(2, 3, 4)
z = wrapped_add(dummy_data_class, y)
return z
mod = TestModuleWithFunctionEntry()
gm = torch.fx.symbolic_trace(mod)
# Create custom operator support to mark wrapped_add as supported
class CustomOpSupport(op_support.OperatorSupportBase):
def is_node_supported(self, submodules, node) -> bool:
return node.target == wrapped_add
# Create a simple splitter to test the edge case
class TestSplitter(splitter_base._SplitterBase):
def __init__(self, module, sample_input, operator_support):
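                # Use the default splitter settings; only operator_support varies here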
settings = splitter_base._SplitterSettingBase()
super().__init__(module, sample_input, operator_support, settings)
# Create splitter instance - this tests the fix where call_function nodes
# with no input dependencies are properly handled in starter_nodes()
splitter = TestSplitter(
module=gm,
sample_input=[torch.randn(2, 3)],
operator_support=CustomOpSupport(),
)
# This should not raise an exception (tests the fix from D81232435)
# The fix allows call_function nodes with no dependencies as valid starter nodes
split_result = splitter()
# Verify the splitting worked correctly
self.assertIsNotNone(split_result)
# Test that the split module produces the same result as the original
test_input = torch.randn(2, 3)
original_result = mod(test_input)
split_module_result = split_result(test_input)
self.assertTrue(torch.equal(original_result, split_module_result))


class TestSplitByTags(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(2, 3)
self.linear2 = torch.nn.Linear(4, 5)
self.linear3 = torch.nn.Linear(6, 7)
self.linear4 = torch.nn.Linear(8, 6)
def forward(
self,
x1: torch.Tensor,
x2: torch.Tensor,
x3: torch.Tensor,
) -> torch.Tensor:
v1 = self.linear1(x1)
v2 = self.linear2(x2)
v3 = self.linear3(x3)
v4 = torch.cat([v1, v2, v3])
return self.linear4(v4)

    @staticmethod
def trace_and_tag(
module: torch.nn.Module, tags: list[str]
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
"""
        Trace the module into a simple gm and tag its nodes as follows (only
        call_module nodes shown):
linear1 - tag: "red"
linear2 - tag: "blue"
linear3, linear4 - tag: "green"
At the beginning we have:
gm:
linear1
linear2
linear3
linear4
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
linear1
blue:
linear2
green:
linear3
linear4
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.fx.symbolic_trace(module)
        # Tag every node and record the names of call_module nodes under each tag
for node in gm.graph.nodes:
if "linear1" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "linear2" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node

    def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitByTags.TestModule()
gm, tag_node = TestSplitByTags.trace_and_tag(module, tags)
split_gm, orig_to_split_fqn_mapping = split_by_tags(
gm, tags, return_fqn_mapping=True
)
        # Ensure split_gm's submodules are named after the tags, in order:
        # red, blue, green
for idx, (name, _) in enumerate(split_gm.named_children()):
if idx < len(tags):
self.assertTrue(
name == tags[idx],
f"split_gm has an incorrect submodule named {name}",
)
        # Ensure each submodule has the expected (ordered) call_module node(s).
        # For example, split_gm.red has (and only has) linear1, while
        # split_gm.green has (and only has) linear3 and linear4, in that order.
for sub_name, sub_graph_module in split_gm.named_children():
node_idx = 0
for node in sub_graph_module.graph.nodes:
if node.op != "call_module":
continue
                self.assertTrue(
                    node.name == tag_node[sub_name][node_idx],
                    f"{sub_name} incorrectly includes {node.name}",
                )
node_idx += 1
self.assertEqual(
orig_to_split_fqn_mapping,
{
"linear1": "red.linear1",
"linear2": "blue.linear2",
"linear3": "green.linear3",
"linear4": "green.linear4",
},
f"{orig_to_split_fqn_mapping=}",
)


class TestSplitOutputType(TestCase):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
self.relu = torch.nn.ReLU()
def forward(self, x):
conv = self.conv(x)
conv = conv * 0.5
relu = self.relu(conv)
return relu

    @staticmethod
def trace_and_tag(
module: torch.nn.Module, inputs: torch.Tensor, tags: list[str]
) -> tuple[torch.fx.GraphModule, dict[str, list[str]]]:
"""
        Trace the module into a simple gm and tag its nodes as follows (only
        call_module nodes shown):
conv - tag: "red"
mul - tag: "blue"
relu - tag: "green"
At the beginning we have:
gm:
conv
mul
relu
split_gm = split_by_tags(gm, tags)
Then we have:
split_gm:
red:
conv
blue:
mul
green:
relu
"""
tag_node = defaultdict(list)
gm: torch.fx.GraphModule = torch.export.export(
module, (inputs,), strict=True
).module()
        # Tag every node and record the names of call_module nodes under each tag
for node in gm.graph.nodes:
if "conv" in node.name:
node.tag = tags[0]
tag_node[tags[0]].append(node.name)
elif "mul" in node.name:
node.tag = tags[1]
tag_node[tags[1]].append(node.name)
else:
node.tag = tags[2]
if node.op == "call_module":
tag_node[tags[2]].append(node.name)
return gm, tag_node

    def test_split_by_tags(self) -> None:
tags = ["red", "blue", "green"]
module = TestSplitOutputType.TestModule()
inputs = torch.randn((1, 3, 224, 224))
gm, _ = TestSplitOutputType.trace_and_tag(module, inputs, tags)
split_gm, _ = split_by_tags(gm, tags, return_fqn_mapping=True)
gm_output = module(inputs)
split_gm_output = split_gm(inputs)
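        # E721-style check: compare types with `is` rather than `==`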
self.assertTrue(type(gm_output) is type(split_gm_output))
self.assertTrue(torch.equal(gm_output, split_gm_output))


if __name__ == "__main__":
raise RuntimeError(
"This test is not currently used and should be "
"enabled in discover_tests.py if required."
)