Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Summary: Some graphs produced by the minifier's graph cutter cannot be used for AOTI/export (illegal graphs). These should be treated as graphs that don't fail in the minifier, so that the minifier keeps searching.

One example is the following graph, where `true_graph_0` is an `fx.GraphModule`. Here, `export.export()` would raise a `UserError` with `ErrorType = UserErrorType.INVALID_OUTPUT`.

```
# graph():
#     %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
#     return (true_graph_0,)
```

This graph could be obtained from the module below:

```python
import torch


class M(torch.nn.Module):
    def forward(self, x, flag):
        flag = flag.item()

        def true_fn(x):
            return x.clone()

        return torch.cond(flag > 0, true_fn, true_fn, [x])
```

So we detect such errors and exclude them from the minifier's search, treating these graphs as not failing. This is OK and won't miss any actual errors: the AOTI minifier is only designed to catch errors in the AOTI phase, and it is not responsible for catching export bugs.

Test Plan:

```
buck2 run fbcode//caffe2/test/inductor:test_minifier_utils -- -r invalid_output
```

Differential Revision: D66143487

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140999

Approved by: https://github.com/henrylhtsang
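To make the reasoning in the summary concrete, here is a toy, torch-free sketch of why illegal graphs must count as "did not fail". Everything in it is hypothetical: a "graph" is just a list of op names, `export` and `aoti_compile` are stand-ins (one rejects a graph whose output is a submodule reference, the other crashes on one op), and `UserError`/`UserErrorType` are local imitations of the classes in `torch._dynamo.exc`, not imports of them. The `minify` loop is a simple shortest-slice search, not the algorithm `functorch.compile.minifier` actually uses. The sketch compares a predicate that counts every exception as a reproduction with one that filters out `INVALID_OUTPUT` export errors.

```python
from enum import Enum, auto
from typing import Callable, List, Optional

Graph = List[str]  # hypothetical toy graph: a list of op names


class UserErrorType(Enum):
    # Local stand-in for torch._dynamo.exc.UserErrorType (only what we need).
    INVALID_OUTPUT = auto()
    OTHER = auto()


class UserError(Exception):
    # Local stand-in for torch._dynamo.exc.UserError, which records an error_type.
    def __init__(self, error_type: UserErrorType, msg: str) -> None:
        super().__init__(msg)
        self.error_type = error_type


def export(graph: Graph) -> Graph:
    # Toy export: rejects a graph whose output is a submodule reference,
    # like the `get_attr[target=true_graph_0]` graph in the summary.
    if graph and graph[-1] == "get_attr":
        raise UserError(UserErrorType.INVALID_OUTPUT, "output is a GraphModule")
    return graph


def aoti_compile(graph: Graph) -> None:
    # Toy AOTI compile step that crashes whenever one specific op is present.
    if "buggy_op" in graph:
        raise RuntimeError("AOTI compilation failed")


def export_or_none(graph: Graph) -> Optional[Graph]:
    # Return None for illegal graphs instead of reporting an error.
    try:
        return export(graph)
    except UserError as e:
        if e.error_type is UserErrorType.INVALID_OUTPUT:
            return None
        raise  # assumption: any other export error is surfaced unchanged


def fails(graph: Graph) -> bool:
    # A candidate "fails" only if it exports and then breaks in AOTI.
    exported = export_or_none(graph)
    if exported is None:
        return False  # illegal graph: treat as "did not fail", keep searching
    try:
        aoti_compile(exported)
    except RuntimeError:
        return True
    return False


def naive_fails(graph: Graph) -> bool:
    # Counts *any* exception, including export's, as a reproduction.
    try:
        aoti_compile(export(graph))
    except Exception:
        return True
    return False


def minify(graph: Graph, fails_fn: Callable[[Graph], bool]) -> Graph:
    # Try contiguous slices, shortest first; return the first that fails.
    n = len(graph)
    for length in range(1, n + 1):
        for start in range(n - length + 1):
            candidate = graph[start : start + length]
            if fails_fn(candidate):
                return candidate
    return graph


original = ["get_attr", "add", "buggy_op"]
print(minify(original, naive_fails))  # ['get_attr']  (an illegal graph)
print(minify(original, fails))        # ['buggy_op']  (the real AOTI bug)
```

With the naive predicate the search stops at the illegal one-op graph, because its export error looks like a failure. With the filtered predicate the illegal candidate is skipped and the search lands on the op that actually breaks in the AOTI phase.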
28 lines
933 B
Python
# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.repro.aoti import export_for_aoti_minifier
from torch.testing._internal.common_utils import run_tests, TestCase


class MinifierUtilsTests(TestCase):
    def test_invalid_output(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                # Return a submodule instead of a tensor, so the traced graph's
                # output is a get_attr node (an invalid output for export).
                return self.linear

        model = SimpleModel()
        # Here we obtain a graph with an invalid output via symbolic_trace for
        # simplicity; it can also be obtained by running
        # functorch.compile.minifier on an exported graph.
        traced = torch.fx.symbolic_trace(model)
        # export_for_aoti_minifier is expected to return None for this illegal graph.
        gm = export_for_aoti_minifier(traced, (torch.randn(2, 2),))
        self.assertIsNone(gm)


if __name__ == "__main__":
    run_tests()