switch prefer_deferred_runtime_asserts_over_guards in export (#160111)

Summary:
In preparation for checking shape guards in export, this PR effectively switches `prefer_deferred_runtime_asserts_over_guards` to `False`, matching Dynamo.

Actually that's a lie: we switch it to track `allow_complex_guards_as_runtime_asserts`, which is `False` by default but can be set to `True` via an internal API. This keeps the two flags in sync, so we should now be able to kill `allow_complex_guards_as_runtime_asserts` altogether.
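
For context, a minimal usage sketch of the new public kwarg, modeled on the `ModConstraint` test updated below (using `Dim.DYNAMIC` for the dynamic dims is an assumption of the sketch, not something this diff adds):

```python
# Sketch only: mirrors the ModConstraint test below; Dim.DYNAMIC is assumed here.
import torch
from torch.export import Dim, export


class ModConstraint(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This view emits a "complex" divisibility guard on the dynamic shapes.
        return x.view(x.shape[0] - 1, -1)


dyn = Dim.DYNAMIC
# Default (False): the complex guard is not deferred, so the exported graph has
# no _assert_scalar node for it; mismatched inputs still fail inside aten.view.
ep = export(ModConstraint(), (torch.randn(3, 4),), dynamic_shapes={"x": (dyn, dyn)})

# Opting in defers the guard to a runtime assert in the exported graph.
ep_deferred = export(
    ModConstraint(),
    (torch.randn(3, 4),),
    dynamic_shapes={"x": (dyn, dyn)},
    prefer_deferred_runtime_asserts_over_guards=True,
)
num_asserts = sum(
    node.target == torch.ops.aten._assert_scalar.default
    for node in ep_deferred.graph.nodes
)
```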

Test Plan:
updated tests

Differential Revision: D79734206

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160111
Approved by: https://github.com/tugsbayasgalan
Author: Avik Chaudhuri
Date: 2025-08-27 22:51:08 +00:00
Committed by: PyTorch MergeBot
Parent: 6b051d7de3
Commit: 12c0cf3fab
7 changed files with 86 additions and 42 deletions

View File

@@ -318,10 +318,7 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1)
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
eq = sym_size_int_2 == 4; sym_size_int_2 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s27, 4) on node 'eq'"); eq = _assert_scalar_default = None
return pytree.tree_unflatten((linear,), self._out_spec)""",
)

View File

@@ -11381,7 +11381,6 @@ graph():
ep.module()(4, torch.randn(4, 4))
@testing.expectedFailureCppRuntime
@testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3
def test_symint_input_ranges(self):
class M(torch.nn.Module):
def forward(self, x, y):
@@ -11415,8 +11414,12 @@ graph():
)
constraints = list(ep.range_constraints.values())
constraint = constraints[0]
self.assertEqual(constraint.lower, 4)
self.assertEqual(constraint.upper, 5)
# retracebility does not remember the range asserts in the forward
lower, upper = (
(3, 10) if is_retracebility_test(self._testMethodName) else (4, 5)
)
self.assertEqual(constraint.lower, lower)
self.assertEqual(constraint.upper, upper)
# While tracing the range was found to be bigger than the original range
class M(torch.nn.Module):
@@ -13440,22 +13443,36 @@ def forward(self, x, y):
"y": [Dim(f"dy{i}", min=2) for i in range(2)],
"z": [Dim(f"dz{i}", min=4) for i in range(1)],
}
ep = torch.export._trace._export(
FreeReshape(),
inputs,
dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True,
)
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
self.assertEqual(out1.shape, torch.ones(48).shape)
out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
self.assertEqual(out2.shape, torch.ones(40).shape)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
for private_api in (True, False):
if private_api:
ep = torch.export._trace._export(
FreeReshape(),
inputs,
dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True,
)
else:
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
self.assertEqual(out1.shape, torch.ones(48).shape)
out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
self.assertEqual(out2.shape, torch.ones(40).shape)
if private_api:
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
): # fail only at runtime
ep.module()(
torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)
) # fail
else:
# no runtime assert in exported module but it fails anyway with wrong inputs
with self.assertRaisesRegex(
RuntimeError,
r"The size of tensor a \(40\) must match the size of tensor b \(20\) at non-singleton dimension 0",
):
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30))
# case 3: 3d reshape (previously failing with different issue)
class Reshape3d(torch.nn.Module):
@@ -14925,21 +14942,41 @@ def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.view(x.shape[0] - 1, -1)
ep = export(
ModConstraint(),
(torch.randn(3, 4),),
dynamic_shapes={
"x": (dynamic, dynamic),
},
)
ep.module()(torch.randn(5, 8))
num_asserts = [
node.target == torch.ops.aten._assert_scalar.default
for node in ep.graph.nodes
].count(True)
self.assertEqual(num_asserts, 2)
with self.assertRaises(RuntimeError):
ep.module()(torch.randn(4, 2))
for private_api in (True, False):
if private_api:
ep = torch.export._trace._export(
ModConstraint(),
(torch.randn(3, 4),),
dynamic_shapes={"x": (dynamic, dynamic)},
allow_complex_guards_as_runtime_asserts=True,
)
else:
ep = export(
ModConstraint(),
(torch.randn(3, 4),),
dynamic_shapes={"x": (dynamic, dynamic)},
)
ep.module()(torch.randn(5, 8))
num_asserts = [
node.target == torch.ops.aten._assert_scalar.default
for node in ep.graph.nodes
].count(True)
if private_api:
self.assertEqual(num_asserts, 7)
with self.assertRaisesRegex(
RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
):
ep.module()(torch.randn(4, 2))
else:
# no runtime assert in exported module
self.assertEqual(num_asserts, 0)
# but it fails anyway with wrong inputs
with self.assertRaisesRegex(
RuntimeError,
r"shape '\[3, -1\]' is invalid for input of size 8",
):
ep.module()(torch.randn(4, 2))
@testing.expectedFailureSerDer # T195866111
@testing.expectedFailureSerDerNonStrict

View File

@@ -933,7 +933,7 @@ def forward(self, x):
fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
torch.ops.aten.sym_size.int
)
self.assertEqual(fn_count_sym_size(unflat.graph), 3)
self.assertEqual(fn_count_sym_size(unflat.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)

View File

@@ -159,7 +159,11 @@ class AOTIRunnerUtil:
with torch.no_grad():
# strict=False needs extra migration work
ep = torch.export.export(
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
model,
example_inputs,
dynamic_shapes=dynamic_shapes,
strict=True,
prefer_deferred_runtime_asserts_over_guards=True,
)
package_path = torch._inductor.aoti_compile_and_package(
ep, inductor_configs=inductor_configs

View File

@@ -382,7 +382,7 @@ def make_fake_inputs(
shape_env=ShapeEnv(
tracked_fakes=[],
co_fields=co_fields,
prefer_deferred_runtime_asserts_over_guards=True,
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
trace_asserts=True,
),

View File

@@ -69,6 +69,7 @@ def export_for_training(
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
"""
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -157,6 +158,7 @@ def export_for_training(
dynamic_shapes,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
)
@@ -168,6 +170,7 @@ def export(
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
"""
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
@@ -279,6 +282,7 @@ def export(
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
pre_dispatch=True,
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
)
except Exception as e:
draft_export_msg = (

View File

@@ -812,7 +812,7 @@ def _export_to_torch_ir(
disable_constraint_solver=disable_constraint_solver,
# currently the following 2 flags are tied together for export purposes,
# but untangle for sake of dynamo export api
prefer_deferred_runtime_asserts_over_guards=True,
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
_log_export_usage=_log_export_usage,
same_signature=same_signature,
@@ -2037,6 +2037,7 @@ def _export_for_training(
*,
strict: bool = True,
preserve_module_call_signature: tuple[str, ...] = (),
allow_complex_guards_as_runtime_asserts: bool = False,
) -> ExportedProgram:
global _EXPORT_MODULE_HIERARCHY
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2061,7 +2062,7 @@ def _export_for_training(
dynamic_shapes=dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature,
orig_in_spec=orig_in_spec,
allow_complex_guards_as_runtime_asserts=False,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
_to_aten_func=_export_to_aten_ir_make_fx,
)
@@ -2198,6 +2199,7 @@ def _export(
dynamic_shapes,
strict=strict,
preserve_module_call_signature=preserve_module_call_signature,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
)
dtrace_structured("exported_program", payload_fn=lambda: str(ep))
return ep