# Owner(s): ["module: inductor"]
import torch
from torch._dynamo.repro.aoti import (
    AOTIMinifierError,
    export_for_aoti_minifier,
    get_module_string,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class MinifierUtilsTests(TestCase):
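    # export_for_aoti_minifier is expected to return None (rather than raise)
    # for graphs whose outputs are not tensors.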
    def test_invalid_output(self):
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                # return a submodule instead of a tensor, so the resulting
                # graph has an invalid (non-tensor) output
                return self.linear

        model = SimpleModel()
        # Here we obtain a graph with an invalid output via symbolic_trace for
        # simplicity; such a graph can also be obtained by running
        # functorch.compile.minifier on an exported graph.
        traced = torch.fx.symbolic_trace(model)
        for strict in [True, False]:
            gm = export_for_aoti_minifier(traced, (torch.randn(2, 2),), strict=strict)
            self.assertIsNone(gm)
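
    # With skip_export_error=True, export failures surface as a None result;
    # with skip_export_error=False they are re-raised as AOTIMinifierError.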
    def test_non_exportable(self):
        class SimpleModel(torch.nn.Module):
            def forward(self, x):
                return x.sum()

        model = SimpleModel()
        # Force an export failure by passing two inputs to a forward that
        # accepts only one.
        inputs = (torch.randn(2), torch.randn(2))
        for strict in [True, False]:
            gm = export_for_aoti_minifier(
                model, inputs, strict=strict, skip_export_error=True
            )
            print(gm)
            self.assertIsNone(gm)

            with self.assertRaises(AOTIMinifierError):
                export_for_aoti_minifier(
                    model, inputs, strict=strict, skip_export_error=False
                )
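
    # get_module_string renders the unlifted exported module as a
    # commented-out source string for minifier repros.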
    def test_convert_module_to_string(self):
        class M(torch.nn.Module):
            def forward(self, x, flag):
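                # .item() extracts a Python scalar from the flag tensor,
                # making the branch condition below data-dependent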
                flag = flag.item()

                def true_fn(x):
                    return x.clone()
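
                # both branches are given the same function, yet the exported
                # program still carries two separate branch submodules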
                return torch.cond(flag > 0, true_fn, true_fn, [x])

        inputs = (
            torch.rand(28, 28),
            torch.tensor(1),
        )

        model = M()
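        # check_guards=False skips emitting the shape-guard check when
        # unlifting the exported program into a plain nn.Module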
        gm = torch.export.export(model, inputs, strict=False).module(check_guards=False)

        # TODO: make NNModuleToString.convert() generate strings for nested submodules.
        model_string = get_module_string(gm)
        self.assertExpectedInline(
            model_string.strip(),
            """\
# from torch.nn import *
# class Repro(torch.nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#         self.true_graph_0 = <lambda>()
#         self.false_graph_0 = <lambda>()



#     def forward(self, x, flag):
#         x, flag, = fx_pytree.tree_flatten_spec(([x, flag], {}), self._in_spec)
#         item = torch.ops.aten.item.default(flag);  flag = None
#         gt = item > 0;  item = None
#         true_graph_0 = self.true_graph_0
#         false_graph_0 = self.false_graph_0
#         cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x,));  gt = true_graph_0 = false_graph_0 = x = None
#         getitem = cond[0];  cond = None
#         return pytree.tree_unflatten((getitem,), self._out_spec)""",
        )


if __name__ == "__main__":
    run_tests()