pytorch/test/inductor/test_minifier_utils.py
Shangdi Yu c05813d2a9 [AOTI Minifier] Exclude illegal graphs from minifier search (#140999)
Summary:
Some graphs produced by the minifier's graph cutter cannot be used for AOTI/export (illegal graphs). These should be treated as graphs that don't fail in the minifier, so that the minifier keeps searching.
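A minimal sketch of the search behavior described above, assuming a toy minifier that greedily drops one node at a time. `IllegalGraphError`, `candidate_fails`, `minify`, and `toy_run` are hypothetical names, not PyTorch APIs, and graphs are modeled as plain lists of node names rather than `fx.GraphModule`s. The point it illustrates is that an illegal candidate is reported as "not failing", so it is rejected and the search moves on.

```python
from typing import Callable, List, Sequence


class IllegalGraphError(Exception):
    """Hypothetical stand-in for an export-time error such as INVALID_OUTPUT."""


def candidate_fails(candidate: Sequence[str], run: Callable[[Sequence[str]], None]) -> bool:
    """Return True only if `run` hits a genuine (post-export) failure.

    Candidates that cannot even be exported are reported as *not* failing,
    so the minifier rejects them and keeps searching.
    """
    try:
        run(candidate)
    except IllegalGraphError:
        return False  # illegal graph: not the bug we are minifying
    except Exception:
        return True  # a real failure in the later (AOTI-like) phase
    return False


def minify(nodes: List[str], run: Callable[[Sequence[str]], None]) -> List[str]:
    """Greedily remove one node at a time while the failure still reproduces."""
    assert candidate_fails(nodes, run), "the full graph must fail to start with"
    current = list(nodes)
    changed = True
    while changed:
        changed = False
        for i in range(len(current)):
            trial = current[:i] + current[i + 1:]
            if candidate_fails(trial, run):
                current = trial
                changed = True
                break
    return current


def toy_run(nodes: Sequence[str]) -> None:
    """Hypothetical pipeline: no output node is illegal; 'bad_op' is the real bug."""
    if "output" not in nodes:
        raise IllegalGraphError("graph has no valid output")
    if "bad_op" in nodes:
        raise RuntimeError("compiler crash")


print(minify(["x", "linear", "bad_op", "relu", "output"], toy_run))
# → ['bad_op', 'output']
```

Without the `except IllegalGraphError` branch, illegal candidates would count as failures and this toy search would shrink all the way to an empty graph, which reproduces only the export error rather than the compiler bug.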

One example is the following graph, where `true_graph_0` is an `fx.GraphModule`. Here, `export.export()` would raise a `UserError` with `ErrorType = UserErrorType.INVALID_OUTPUT`.

```
# graph():
#     %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
#     return (true_graph_0,)
```

This graph could be obtained from the module below:

```python
import torch


class M(torch.nn.Module):
    def forward(self, x, flag):
        flag = flag.item()

        def true_fn(x):
            return x.clone()

        return torch.cond(flag > 0, true_fn, true_fn, [x])
```

So we detect such errors and exclude these graphs from the minifier's search (that is, we treat them as graphs that did not fail).
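The detection step can be sketched as a wrapper that maps the invalid-output error to `None`, which is the result the unit test in this file expects from `export_for_aoti_minifier` (`self.assertTrue(gm is None)`). This is a hypothetical, torch-free illustration: `ExportErrorKind`, `UserError`, and `export_or_none` are stand-ins for the `UserError`/`UserErrorType` named in the summary, not the real PyTorch classes, and the sketch does not reproduce the internals of `export_for_aoti_minifier`.

```python
import enum
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class ExportErrorKind(enum.Enum):
    # Hypothetical mirror of UserErrorType; OTHER represents any other error kind.
    INVALID_OUTPUT = enum.auto()
    OTHER = enum.auto()


class UserError(Exception):
    # Hypothetical stand-in for the export UserError, carrying its error type.
    def __init__(self, error_type: ExportErrorKind, msg: str) -> None:
        super().__init__(msg)
        self.error_type = error_type


def export_or_none(export_fn: Callable[[], T]) -> Optional[T]:
    """Run an export step; return None if the graph has an invalid output.

    Any other error propagates unchanged, so genuine failures are not hidden.
    """
    try:
        return export_fn()
    except UserError as e:
        if e.error_type == ExportErrorKind.INVALID_OUTPUT:
            return None  # illegal graph: caller treats it as "did not fail"
        raise


def bad_export():
    # Simulates exporting a graph whose output is a GraphModule.
    raise UserError(ExportErrorKind.INVALID_OUTPUT, "output is a GraphModule")


print(export_or_none(bad_export))  # prints None
```

Only the one error type is swallowed; narrowing the check this way keeps other export problems visible instead of silently discarding them.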

This is safe and won't miss any actual errors: the AOTI minifier is only designed to catch errors in the AOTI phase, and it is not responsible for catching export bugs.

Test Plan:
```
buck2 run  fbcode//caffe2/test/inductor:test_minifier_utils  -- -r invalid_output
```

Differential Revision: D66143487

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140999
Approved by: https://github.com/henrylhtsang
2024-11-20 03:20:06 +00:00


# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.repro.aoti import export_for_aoti_minifier
from torch.testing._internal.common_utils import run_tests, TestCase


class MinifierUtilsTests(TestCase):
    def test_invalid_output(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                # return a graph module
                return self.linear

        model = SimpleModel()
        # Here we obtained a graph with invalid output by symbolic_trace for simplicity,
        # it can also be obtained from running functorch.compile.minifier on an exported graph.
        traced = torch.fx.symbolic_trace(model)
        gm = export_for_aoti_minifier(traced, (torch.randn(2, 2),))
        self.assertTrue(gm is None)


if __name__ == "__main__":
    run_tests()