Compare commits


1 commit

cd2e4131a9  [WIP][export] runtime asserts on subgraphs (#158391)  2025-07-16 11:05:37 -07:00

Summary:

Rollback Plan:

Differential Revision: D78375400
5 changed files with 43 additions and 14 deletions
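
For context (not stated in the commit message beyond the title): torch.export records torch._check(...)-style constraints on data-dependent values and re-emits them as aten._assert_scalar runtime asserts, and before this change insert_deferred_runtime_asserts only processed the top-level graph. A minimal sketch of that pre-existing top-level behavior, assuming a recent PyTorch build:

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x, idx):
        i = idx.item()              # unbacked SymInt (u0) under export
        torch._check(i >= 0)        # recorded, then re-emitted as runtime asserts
        torch._check(i < x.size(0))
        return x[i]

ep = export(M(), (torch.randn(8), torch.tensor(3)))
print(ep.graph)  # look for aten._assert_scalar.default nodes guarding u0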


@@ -7430,6 +7430,31 @@ def forward(self, x):
         ep = export(Simple(), example_inputs)
         self.assertEqual(ep.module()(*example_inputs), Simple()(*example_inputs))
 
+    @testing.expectedFailureCppRuntime
+    def test_while_loop_index_assertions(self):
+        from torch._higher_order_ops import while_loop
+
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                def cond_fn(idx, acc):
+                    i = idx.item()
+                    torch._check(i >= 0)
+                    return i < x.size(0)
+
+                def body_fn(idx, acc):
+                    i = idx.item()
+                    torch._check_is_size(i, max=x.size(0) - 1)
+                    return idx + 1, acc + x[i]
+
+                acc = torch.zeros(x.size(1))
+                n = torch.full((), 0, dtype=torch.int64)
+                _, out = while_loop(cond_fn, body_fn, [n, acc])
+                return out
+
+        x = torch.randn(8, 4)
+        ep = export(Foo(), (x,))
+        self.assertTrue(torch.allclose(x.sum(dim=0), ep.module()(x)))
+
     def test_constrain_size_with_various_cases(self):
         class Module1(torch.nn.Module):
             def forward(self, x, y):

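A quick way to observe what this test implies (hypothetical inspection code, reusing Foo from the hunk above and assuming a build with this patch): the _assert_scalar guards derived from the torch._check* calls should now appear inside the while_loop subgraphs rather than only in the top-level graph.

ep = export(Foo(), (torch.randn(8, 4),))
for name, mod in ep.graph_module.named_modules():
    if name and isinstance(mod, torch.fx.GraphModule):
        asserts = [
            n for n in mod.graph.nodes
            if n.op == "call_function"
            and n.target is torch.ops.aten._assert_scalar.default
        ]
        print(name, len(asserts))  # expect non-zero counts in the loop subgraphs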

@@ -7717,8 +7717,9 @@ class GraphModule(torch.nn.Module):
     class while_loop_cond_graph_0(torch.nn.Module):
         def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"):
-            sym_size_int: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
-            lt: "Sym(u0 < s77)" = it_1 < sym_size_int; it_1 = sym_size_int = None
+            sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x_1, 0); x_1 = None
+            lt: "Sym(u0 < s77)" = it_1 < sym_size_int_1; it_1 = sym_size_int_1 = None
             return lt
 
     class while_loop_body_graph_0(torch.nn.Module):

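The expecttest churn above looks like FX's fresh-name counter at work: once the pass re-derives the size node in the subgraph, the name sym_size_int is already taken in the graph's namespace, so the node comes back as sym_size_int_1 (a plausible reading; the commit message doesn't elaborate). The suffixing behavior in isolation:

import torch

def f(x):
    return x.relu() + x.relu()

gm = torch.fx.symbolic_trace(f)
print([n.name for n in gm.graph.nodes])
# ['x', 'relu', 'relu_1', 'add', 'output'] -- same target, deduplicated name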

@@ -682,12 +682,13 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature):
     ):
         shape_env = _get_shape_env_from_gm(gm)
         if shape_env:
-            insert_deferred_runtime_asserts(
-                gm,
-                shape_env,
-                f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
-                export=True,
-            )
+            for _gm in gm.modules():
+                insert_deferred_runtime_asserts(
+                    _gm,
+                    shape_env,
+                    f"exported program: {first_call_function_nn_module_stack(gm.graph)}",
+                    export=True,
+                )
 
     # insert runtime assertions for aten.to nodes
     _insert_aten_to_metadata_assert_pass(gm)

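The loop over gm.modules() above assumes that control-flow subgraphs (while_loop/cond bodies) are attached to the exported GraphModule as child GraphModules, so module traversal reaches them. That invariant is easy to check with a small sketch (torch.cond used here for brevity; names like true_graph_0 are what current exports typically produce, not guaranteed):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return torch.cond(x.sum() > 0, lambda y: y + 1, lambda y: y - 1, (x,))

ep = export(M(), (torch.randn(3),))
for name, mod in ep.graph_module.named_modules():
    print(name or "<root>", type(mod).__name__)
# the branch subgraphs (e.g. true_graph_0, false_graph_0) are GraphModule children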

@@ -294,10 +294,10 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
             carried_inputs,
         )
 
-        cond_graph = reenter_make_fx(cond_fn)(
+        body_graph = reenter_make_fx(body_fn)(
             *unspecialized_carried_inputs, *additional_inputs
         )
-        body_graph = reenter_make_fx(body_fn)(
+        cond_graph = reenter_make_fx(cond_fn)(
             *unspecialized_carried_inputs, *additional_inputs
         )

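The only change in this file is trace order: body_fn is now traced before cond_fn. The commit doesn't say why, but a reading consistent with the new test is that the range facts recorded while tracing body_fn (torch._check_is_size on the unbacked index) must already be in the shared ShapeEnv when cond_fn's comparison against the size is traced. Schematically (plain Python stand-ins, not PyTorch internals):

facts = set()

def trace_body_fn():
    # body_fn: torch._check_is_size(i, max=n - 1) records 0 <= u0 < s77
    facts.add("0 <= u0 < s77")

def trace_cond_fn():
    # cond_fn evaluates u0 < s77; without the recorded fact this risks a
    # data-dependent guard error on the unbacked symbol
    assert "0 <= u0 < s77" in facts

trace_body_fn()  # new order: record the facts first
trace_cond_fn()  # then trace the consumer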

@@ -37,8 +37,10 @@ def remove_assertion_nodes(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
         torch.ops.aten._assert_scalar.default,
         torch.ops.aten._assert_tensor_metadata.default,
     }
-    for node in graph_module.graph.nodes:
-        if node.op == "call_function" and node.target in aten_assertion_targets:
-            graph_module.graph.erase_node(node)
-    graph_module.recompile()
+    for gm in graph_module.modules():
+        if isinstance(gm, torch.fx.GraphModule):
+            for node in gm.graph.nodes:
+                if node.op == "call_function" and node.target in aten_assertion_targets:
+                    gm.graph.erase_node(node)
+            gm.recompile()
     return graph_module
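
For reference, the same recursive cleanup written as a standalone pass (hypothetical helper, generic names). Erasing while iterating works in the original because assert nodes are side-effect-only and have no users, but collecting first is the more defensive FX idiom:

import torch

def strip_call_functions(gm: torch.fx.GraphModule, targets) -> torch.fx.GraphModule:
    for sub in gm.modules():
        if not isinstance(sub, torch.fx.GraphModule):
            continue  # plain nn.Module children have no FX graph to rewrite
        dead = [
            n for n in sub.graph.nodes
            if n.op == "call_function" and n.target in targets
        ]
        for n in dead:
            sub.graph.erase_node(n)  # safe: assert nodes have no users
        sub.recompile()
    return gm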