Revert "kill allow_complex_guards_as_runtime_asserts (#160198)"

This reverts commit 69d91b94ba5366f4444d8cb8fd3dab4de4f04d3d.

Reverted https://github.com/pytorch/pytorch/pull/160198 on behalf of https://github.com/jeffdaily due to let's revert again instead of waiting for forward fix, see earlier comments ([comment](https://github.com/pytorch/pytorch/pull/160198#issuecomment-3235165462))
PyTorch MergeBot
2025-08-28 22:50:37 +00:00
parent fffa62fa12
commit 47742081c9
9 changed files with 67 additions and 47 deletions

View File

@@ -10867,8 +10867,8 @@ def ___make_guard_fn():
 ShapeEnv not equal: field values don't match:
 ==> settings: values don't match.
-  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
-  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
+  >  Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
+  > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
 """,
 )
 self._replay_and_check(main)

View File

@@ -5609,11 +5609,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
 dim0_x = torch.export.Dim("dim0_x", min=3)
 dim1_x = torch.export.Dim("dim1_x", max=8000)
 dynamic_shapes = {"x": (dim0_x, dim1_x)}
-em = torch.export.export(
+em = torch.export._trace._export(
 m,
 (a,),
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 em.module()(torch.randn(4, 3))
 with self.assertRaisesRegex(
@@ -13497,7 +13497,7 @@ def forward(self, x, y):
 def test_disable_forced_specializations_ok(self):
 # check that we don't force specialization, and defer to runtime asserts
-# with prefer_deferred_runtime_asserts_over_guards=True to successfully export
+# with allow_complex_guards_as_runtime_asserts=True to successfully export
 # case 1: modulo guards
 from torch.export import dims
@@ -13507,11 +13507,11 @@ def forward(self, x, y):
 inputs = (torch.randn(10, 72),)
 dx, dy = dims("dx", "dy")
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Mod4Reshape(),
 inputs,
 dynamic_shapes={"x": (dx, dy)},
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 out1 = ep.module()(torch.randn(8, 7))
 self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
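
For orientation, a minimal runnable sketch of the pattern this test exercises (not part of the diff; the module body is an illustrative stand-in for the suite's Mod4Reshape, and it assumes the private _export entry point and kwarg restored by this revert):

import torch
from torch.export import dims

class Mod4ReshapeLike(torch.nn.Module):  # hypothetical stand-in for Mod4Reshape
    def forward(self, x):
        # regrouping a fully dynamic matrix emits a guard like
        # Eq(Mod(dx*dy, 8), 0), which the Dim language cannot express
        return x.reshape(-1, 4, 2)

dx, dy = dims("dx", "dy")
ep = torch.export._trace._export(  # private API; kwarg restored by this revert
    Mod4ReshapeLike(),
    (torch.randn(10, 72),),
    dynamic_shapes={"x": (dx, dy)},
    allow_complex_guards_as_runtime_asserts=True,  # defer the Mod guard to runtime
)
print(ep.module()(torch.randn(8, 7)).shape)  # torch.Size([7, 4, 2]), as asserted above
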
@@ -13541,11 +13541,11 @@ def forward(self, x, y):
 for private_api in (True, False):
 if private_api:
-ep = torch.export.export(
+ep = torch.export._trace._export(
 FreeReshape(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 else:
 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13582,11 +13582,11 @@ def forward(self, x, y):
 "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
 "y": (Dim("dy", min=8),),
 }
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Reshape3d(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
 self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13708,11 +13708,11 @@ def forward(self, x, y):
 model = Model()
 x = torch.rand(1024, 20, 16)
 dynamic_shapes = {"x": {0: Dim("batch")}}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 model,
 (x,),
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 with self.assertRaisesRegex(
 RuntimeError,
@@ -13785,11 +13785,11 @@ def forward(self, x, y):
 inputs = (torch.randn(6), torch.randn(12))
 dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 # check forward pass
 out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13824,7 +13824,7 @@ def forward(self, x, y):
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 ).run_decompositions()
 self.assertEqual(
@@ -14236,11 +14236,11 @@ graph():
 inputs = (torch.randn(5), torch.randn(3))
 shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-ep = torch.export.export(
+ep = torch.export._trace._export(
 Foo(),
 inputs,
 dynamic_shapes=shapes,
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 # count 2 pow nodes, 2 sym_size.int nodes
 self.assertEqual(
@@ -15039,11 +15039,11 @@ def forward(self, x):
 for private_api in (True, False):
 if private_api:
-ep = torch.export.export(
+ep = torch.export._trace._export(
 ModConstraint(),
 (torch.randn(3, 4),),
 dynamic_shapes={"x": (dynamic, dynamic)},
-prefer_deferred_runtime_asserts_over_guards=True,
+allow_complex_guards_as_runtime_asserts=True,
 )
 else:
 ep = export(
@@ -15057,7 +15057,7 @@ def forward(self, x):
 for node in ep.graph.nodes
 ].count(True)
 if private_api:
-self.assertEqual(num_asserts, 6)
+self.assertEqual(num_asserts, 7)
 with self.assertRaisesRegex(
 RuntimeError,
 r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",

View File

@@ -258,6 +258,12 @@ capture_dynamic_output_shape_ops = (
 # hybrid backed unbacked symints
 prefer_deferred_runtime_asserts_over_guards = False
 
+# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
+# range constraints + dims + derived dims language, we raise constraint violation
+# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
+# and allows complex guards as runtime assertions in the graph.
+allow_complex_guards_as_runtime_asserts = False
+
 # By default, dynamo will treat all ints as backed SymInts, which means (1) it
 # will wait to see the int change over multiple runs before generalizing and
 # (2) it will still always 0/1 specialize an int. When true, this knob
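
A minimal sketch of toggling the restored knob (the compiled function below is illustrative; per the OutputGraph hunk later in this diff, the flag flows from dynamo config into the ShapeEnv):

import torch

def f(x, y):
    # fully dynamic reshape-and-add can emit a product guard (s0*s1 == s2)
    return x.reshape(-1) + y

# torch._dynamo.config.patch scopes the flag; plain assignment also works
with torch._dynamo.config.patch(allow_complex_guards_as_runtime_asserts=True):
    out = torch.compile(f, dynamic=True)(torch.randn(3, 4), torch.randn(12))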

View File

@@ -1734,6 +1734,7 @@ def export(
 same_signature: bool = True,
 disable_constraint_solver: bool = False,
 prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 _log_export_usage: bool = True,
 constraints: Optional[list[Constraint]] = None,
 **extra_kwargs: Any,
@@ -1960,6 +1961,7 @@ def export(
 capture_dynamic_output_shape_ops=True,
 capture_scalar_outputs=True,
 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 ),
 _compiling_state_context(),
 ):
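
A sketch of the dynamo export entry point with the restored keyword (the toy function is illustrative; note the hunk above passes the two flags through independently, matching the "untangle for sake of dynamo export api" comment later in this diff):

import torch

def f(x):
    return x.reshape(-1) + 1

# torch._dynamo.export(...) returns a callable; applying it to example
# inputs yields an ExportResult of (graph_module, guards)
gm, guards = torch._dynamo.export(
    f,
    allow_complex_guards_as_runtime_asserts=True,  # restored by this revert
)(torch.randn(4, 5))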

View File

@@ -468,6 +468,7 @@ class OutputGraph(OutputGraphGuardsState):
 allow_scalar_outputs=config.capture_scalar_outputs,
 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
 prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
 co_fields=self.co_fields,
 )

View File

@@ -330,7 +330,7 @@ def make_fake_inputs(
 args,
 kwargs,
 dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=False,
+allow_complex_guards_as_runtime_asserts=False,
 ):
 """
 Given an nn module, example inputs, and constraints, return a new fake mode,
@@ -382,7 +382,8 @@ def make_fake_inputs(
 shape_env=ShapeEnv(
 tracked_fakes=[],
 co_fields=co_fields,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 trace_asserts=True,
 ),
 allow_non_fake_inputs=True,

View File

@@ -158,7 +158,7 @@ def export_for_training(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
 )
@@ -282,7 +282,7 @@ def export(
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
 pre_dispatch=True,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
 )
 except Exception as e:
 draft_export_msg = (
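
After this revert, the public surface keeps the prefer_deferred_runtime_asserts_over_guards name and maps it onto the internal allow_complex_guards_as_runtime_asserts parameter, as the two hunks above show. A minimal sketch (module and shapes illustrative):

import torch
from torch.export import Dim

class Flatten(torch.nn.Module):  # illustrative module
    def forward(self, x):
        return x.reshape(-1)

ep = torch.export.export(
    Flatten(),
    (torch.randn(3, 4),),
    dynamic_shapes={"x": (Dim("d0"), Dim("d1"))},
    prefer_deferred_runtime_asserts_over_guards=True,  # public kwarg, mapped internally
)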

View File

@@ -750,7 +750,7 @@ def _export_to_torch_ir(
 *,
 preserve_module_call_signature: tuple[str, ...] = (),
 disable_constraint_solver: bool = False,
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 restore_fqn: bool = True,
 _log_export_usage: bool = True,
 same_signature: bool = True,
@@ -810,7 +810,10 @@ def _export_to_torch_ir(
 assume_static_by_default=True,
 tracing_mode="symbolic",
 disable_constraint_solver=disable_constraint_solver,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+# currently the following 2 flags are tied together for export purposes,
+# but untangle for sake of dynamo export api
+prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _log_export_usage=_log_export_usage,
 same_signature=same_signature,
 )(
@@ -1399,7 +1402,7 @@ def _strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-prefer_deferred_runtime_asserts_over_guards: bool,
+allow_complex_guards_as_runtime_asserts: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1413,7 +1416,7 @@ def _strict_export(
 dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 restore_fqn=False,  # don't need to restore because we will do it later
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _log_export_usage=False,
 )
@@ -1861,7 +1864,7 @@ def _non_strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-prefer_deferred_runtime_asserts_over_guards: bool,
+allow_complex_guards_as_runtime_asserts: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1958,7 +1961,7 @@ def _non_strict_export(
 args,
 kwargs,
 dynamic_shapes,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,  # for shape env initialization
 )
 fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
@@ -2076,7 +2079,7 @@ def _export_for_training(
 *,
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 ) -> ExportedProgram:
 global _EXPORT_MODULE_HIERARCHY
 _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2106,7 +2109,7 @@ def _export_for_training(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=orig_in_spec,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _to_aten_func=_export_to_aten_ir_make_fx,
 )
@@ -2177,7 +2180,7 @@ def _export(
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
 pre_dispatch: bool = False,
-prefer_deferred_runtime_asserts_over_guards: bool = False,
+allow_complex_guards_as_runtime_asserts: bool = False,
 ) -> ExportedProgram:
 """
 Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2208,7 +2211,7 @@ def _export(
 preserve_module_call_signature: A list of submodule paths for which the original
 calling conventions are preserved as metadata.
-prefer_deferred_runtime_asserts_over_guards:
+allow_complex_guards_as_runtime_asserts:
 With the current dynamic shapes language for dims and derived dims, we can run into constraints
 that are not expressible with the language. For example, flattening a matrix and adding to a vector,
 both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
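
The docstring's reshape example as a runnable sketch (module body and shapes are illustrative; uses the private entry point with the restored flag):

import torch
from torch.export import Dim

class FlattenAdd(torch.nn.Module):  # illustrative: x.reshape([-1]) + y
    def forward(self, x, y):
        # guards s0 * s1 = s2, which the dims language cannot express
        return x.reshape([-1]) + y

ep = torch.export._trace._export(
    FlattenAdd(),
    (torch.randn(3, 4), torch.randn(12)),
    dynamic_shapes={"x": (Dim("s0"), Dim("s1")), "y": (Dim("s2"),)},
    allow_complex_guards_as_runtime_asserts=True,  # assert at runtime instead of specializing
)
ep.module()(torch.randn(2, 6), torch.randn(12))  # ok: 2 * 6 == 12
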
@@ -2252,7 +2255,7 @@ def _export(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 )
 dtrace_structured("exported_program", payload_fn=lambda: str(ep))
 return ep
@@ -2277,7 +2280,7 @@ def _export(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=original_in_spec,
-prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 _to_aten_func=functools.partial(
 _export_to_aten_ir,
 pre_dispatch=pre_dispatch,

View File

@@ -3536,6 +3536,7 @@ class ShapeEnvSettings:
 specialize_zero_one: bool
 duck_shape: bool
 prefer_deferred_runtime_asserts_over_guards: bool
+allow_complex_guards_as_runtime_asserts: bool
 trace_asserts: bool
@@ -3673,6 +3674,10 @@ class ShapeEnv:
 # in guards is helpful, since these guards in some sense are overly
 # pedantic. See also https://github.com/pytorch/pytorch/issues/121749
 prefer_deferred_runtime_asserts_over_guards: bool = False,
+# When True, does not emit or raise constraint violation errors on
+# implicit guards generated by ops, and defers to runtime assertions
+# in the graph instead. For export.
+allow_complex_guards_as_runtime_asserts: bool = False,
 # XXX Add any new settings that could affect FakeTensor evaluation
 # to: torch._subclasses.fake_tensor._ShapeEnvSettings
 trace_asserts: bool = False,
@@ -3689,6 +3694,7 @@ class ShapeEnv:
 specialize_zero_one=specialize_zero_one,
 duck_shape=duck_shape,
 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
+allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 trace_asserts=trace_asserts,
 )
@@ -3900,6 +3906,10 @@ class ShapeEnv:
 def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
 return self.settings.prefer_deferred_runtime_asserts_over_guards
 
+@property
+def allow_complex_guards_as_runtime_asserts(self) -> bool:
+return self.settings.allow_complex_guards_as_runtime_asserts
+
 @contextmanager
 def patch_source_specialization(
 self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
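
A sketch of the restored setting end to end on the internal ShapeEnv (internal API; constructor defaults as declared in this diff):

from torch.fx.experimental.symbolic_shapes import ShapeEnv

env = ShapeEnv(allow_complex_guards_as_runtime_asserts=True)
# recorded in ShapeEnvSettings and surfaced via the restored property
assert env.settings.allow_complex_guards_as_runtime_asserts
assert env.allow_complex_guards_as_runtime_asserts
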
@@ -6649,7 +6659,7 @@ class ShapeEnv:
 assert isinstance(a, sympy.Symbol)
 if (
-self.prefer_deferred_runtime_asserts_over_guards
+self.allow_complex_guards_as_runtime_asserts
 and not _is_supported_equivalence(tgt)
 ):
 return  # continuing leads to placeholder shapes having complex expressions that we can't resolve
@@ -7631,15 +7641,7 @@ class ShapeEnv:
 # is no longer necessary)
 self._maybe_guard_rel(g)
-if (
-torch.compiler.is_exporting()
-and self.prefer_deferred_runtime_asserts_over_guards
-):
-# it's fine to defer simple guards here without checking,
-# the _maybe_guard_rel() call above will set replacements if possible,
-# and so the result here will be statically known
-self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
-else:
+if not self.allow_complex_guards_as_runtime_asserts:
 # at this point, we've evaluated the concrete expr value, and have
 # flipped/negated the guard if necessary. Now we know what to guard
 # or defer to runtime assert on.
@@ -7648,6 +7650,11 @@ class ShapeEnv:
 )
 self.guards.append(guard)
 self.axioms.update(dict(self.get_implications(self.simplify(g))))
+else:
+# it's fine to defer simple guards here without checking,
+# the _maybe_guard_rel() call above will set replacements if possible,
+# and so the result here will be statically known
+self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
 else:
 self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
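
The observable effect of this restored branch, sketched in the style of the num_asserts test earlier in this diff (the module is an illustrative stand-in for ModConstraint, and the _assert_scalar predicate is an assumption about how deferred guards surface in the exported graph):

import torch
from torch.export import Dim

class ModConstraintLike(torch.nn.Module):  # hypothetical stand-in for ModConstraint
    def forward(self, x):
        # guards something like Eq(Mod(s0*s1, s0 - 1), 0)
        return x.reshape(-1, x.shape[0] - 1)

ep = torch.export._trace._export(
    ModConstraintLike(),
    (torch.randn(3, 4),),
    dynamic_shapes={"x": (Dim("dx"), Dim("dy"))},
    allow_complex_guards_as_runtime_asserts=True,
)
num_asserts = [
    node.target == torch.ops.aten._assert_scalar.default
    for node in ep.graph.nodes
].count(True)
print(num_asserts)  # complex guards now appear as runtime asserts in the graph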