kill allow_complex_guards_as_runtime_asserts (#160198)

Summary: Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was a export-only concept).

Test Plan:
updated tests

Rollback Plan:

Differential Revision: D79903317

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160198
Approved by: https://github.com/ezyang
This commit is contained in:
Avik Chaudhuri
2025-08-28 07:59:29 +00:00
committed by PyTorch MergeBot
parent fa76256603
commit 196232bb93
9 changed files with 47 additions and 67 deletions

View File

@ -10916,8 +10916,8 @@ def ___make_guard_fn():
ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match:
==> settings: values don't match. ==> settings: values don't match.
> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False) > Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False) > Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
""", """,
) )
self._replay_and_check(main) self._replay_and_check(main)

View File

@ -5514,11 +5514,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dim0_x = torch.export.Dim("dim0_x", min=3) dim0_x = torch.export.Dim("dim0_x", min=3)
dim1_x = torch.export.Dim("dim1_x", max=8000) dim1_x = torch.export.Dim("dim1_x", max=8000)
dynamic_shapes = {"x": (dim0_x, dim1_x)} dynamic_shapes = {"x": (dim0_x, dim1_x)}
em = torch.export._trace._export( em = torch.export.export(
m, m,
(a,), (a,),
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
em.module()(torch.randn(4, 3)) em.module()(torch.randn(4, 3))
with self.assertRaisesRegex( with self.assertRaisesRegex(
@ -13402,7 +13402,7 @@ def forward(self, x, y):
def test_disable_forced_specializations_ok(self): def test_disable_forced_specializations_ok(self):
# check that we don't force specialization, and defer to runtime asserts # check that we don't force specialization, and defer to runtime asserts
# with allow_complex_guards_as_runtime_asserts=True to successfully export # with prefer_deferred_runtime_asserts_over_guards=True to successfully export
# case 1: modulo guards # case 1: modulo guards
from torch.export import dims from torch.export import dims
@ -13412,11 +13412,11 @@ def forward(self, x, y):
inputs = (torch.randn(10, 72),) inputs = (torch.randn(10, 72),)
dx, dy = dims("dx", "dy") dx, dy = dims("dx", "dy")
ep = torch.export._trace._export( ep = torch.export.export(
Mod4Reshape(), Mod4Reshape(),
inputs, inputs,
dynamic_shapes={"x": (dx, dy)}, dynamic_shapes={"x": (dx, dy)},
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
out1 = ep.module()(torch.randn(8, 7)) out1 = ep.module()(torch.randn(8, 7))
self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape) self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@ -13446,11 +13446,11 @@ def forward(self, x, y):
for private_api in (True, False): for private_api in (True, False):
if private_api: if private_api:
ep = torch.export._trace._export( ep = torch.export.export(
FreeReshape(), FreeReshape(),
inputs, inputs,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
else: else:
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes) ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@ -13487,11 +13487,11 @@ def forward(self, x, y):
"x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)), "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
"y": (Dim("dy", min=8),), "y": (Dim("dy", min=8),),
} }
ep = torch.export._trace._export( ep = torch.export.export(
Reshape3d(), Reshape3d(),
inputs, inputs,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126)) out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
self.assertEqual(out1.shape, torch.ones(126).shape) self.assertEqual(out1.shape, torch.ones(126).shape)
@ -13613,11 +13613,11 @@ def forward(self, x, y):
model = Model() model = Model()
x = torch.rand(1024, 20, 16) x = torch.rand(1024, 20, 16)
dynamic_shapes = {"x": {0: Dim("batch")}} dynamic_shapes = {"x": {0: Dim("batch")}}
ep = torch.export._trace._export( ep = torch.export.export(
model, model,
(x,), (x,),
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
@ -13690,11 +13690,11 @@ def forward(self, x, y):
inputs = (torch.randn(6), torch.randn(12)) inputs = (torch.randn(6), torch.randn(12))
dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]} dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
ep = torch.export._trace._export( ep = torch.export.export(
Foo(), Foo(),
inputs, inputs,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
# check forward pass # check forward pass
out0, out1 = ep.module()(torch.randn(9), torch.randn(27)) out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@ -13729,7 +13729,7 @@ def forward(self, x, y):
Foo(), Foo(),
inputs, inputs,
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
).run_decompositions() ).run_decompositions()
self.assertEqual( self.assertEqual(
@ -14141,11 +14141,11 @@ graph():
inputs = (torch.randn(5), torch.randn(3)) inputs = (torch.randn(5), torch.randn(3))
shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)} shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
ep = torch.export._trace._export( ep = torch.export.export(
Foo(), Foo(),
inputs, inputs,
dynamic_shapes=shapes, dynamic_shapes=shapes,
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
# count 2 pow nodes, 2 sym_size.int nodes # count 2 pow nodes, 2 sym_size.int nodes
self.assertEqual( self.assertEqual(
@ -14944,11 +14944,11 @@ def forward(self, x):
for private_api in (True, False): for private_api in (True, False):
if private_api: if private_api:
ep = torch.export._trace._export( ep = torch.export.export(
ModConstraint(), ModConstraint(),
(torch.randn(3, 4),), (torch.randn(3, 4),),
dynamic_shapes={"x": (dynamic, dynamic)}, dynamic_shapes={"x": (dynamic, dynamic)},
allow_complex_guards_as_runtime_asserts=True, prefer_deferred_runtime_asserts_over_guards=True,
) )
else: else:
ep = export( ep = export(
@ -14962,7 +14962,7 @@ def forward(self, x):
for node in ep.graph.nodes for node in ep.graph.nodes
].count(True) ].count(True)
if private_api: if private_api:
self.assertEqual(num_asserts, 7) self.assertEqual(num_asserts, 6)
with self.assertRaisesRegex( with self.assertRaisesRegex(
RuntimeError, RuntimeError,
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)", r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",

View File

@ -258,12 +258,6 @@ capture_dynamic_output_shape_ops = (
# hybrid backed unbacked symints # hybrid backed unbacked symints
prefer_deferred_runtime_asserts_over_guards = False prefer_deferred_runtime_asserts_over_guards = False
# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
# range constraints + dims + derived dims language, we raise constraint violation
# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
# and allows complex guards as runtime assertions in the graph.
allow_complex_guards_as_runtime_asserts = False
# By default, dynamo will treat all ints as backed SymInts, which means (1) it # By default, dynamo will treat all ints as backed SymInts, which means (1) it
# will wait to see the int change over multiple runs before generalizing and # will wait to see the int change over multiple runs before generalizing and
# (2) it will still always 0/1 specialize an int. When true, this knob # (2) it will still always 0/1 specialize an int. When true, this knob

View File

@ -1734,7 +1734,6 @@ def export(
same_signature: bool = True, same_signature: bool = True,
disable_constraint_solver: bool = False, disable_constraint_solver: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False,
allow_complex_guards_as_runtime_asserts: bool = False,
_log_export_usage: bool = True, _log_export_usage: bool = True,
constraints: Optional[list[Constraint]] = None, constraints: Optional[list[Constraint]] = None,
**extra_kwargs: Any, **extra_kwargs: Any,
@ -1961,7 +1960,6 @@ def export(
capture_dynamic_output_shape_ops=True, capture_dynamic_output_shape_ops=True,
capture_scalar_outputs=True, capture_scalar_outputs=True,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
), ),
_compiling_state_context(), _compiling_state_context(),
): ):

View File

@ -468,7 +468,6 @@ class OutputGraph(OutputGraphGuardsState):
allow_scalar_outputs=config.capture_scalar_outputs, allow_scalar_outputs=config.capture_scalar_outputs,
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
co_fields=self.co_fields, co_fields=self.co_fields,
) )

View File

@ -330,7 +330,7 @@ def make_fake_inputs(
args, args,
kwargs, kwargs,
dynamic_shapes, dynamic_shapes,
allow_complex_guards_as_runtime_asserts=False, prefer_deferred_runtime_asserts_over_guards=False,
): ):
""" """
Given an nn module, example inputs, and constraints, return a new fake mode, Given an nn module, example inputs, and constraints, return a new fake mode,
@ -382,8 +382,7 @@ def make_fake_inputs(
shape_env=ShapeEnv( shape_env=ShapeEnv(
tracked_fakes=[], tracked_fakes=[],
co_fields=co_fields, co_fields=co_fields,
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
trace_asserts=True, trace_asserts=True,
), ),
allow_non_fake_inputs=True, allow_non_fake_inputs=True,

View File

@ -158,7 +158,7 @@ def export_for_training(
dynamic_shapes, dynamic_shapes,
strict=strict, strict=strict,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
) )
@ -282,7 +282,7 @@ def export(
strict=strict, strict=strict,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
pre_dispatch=True, pre_dispatch=True,
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
) )
except Exception as e: except Exception as e:
draft_export_msg = ( draft_export_msg = (

View File

@ -750,7 +750,7 @@ def _export_to_torch_ir(
*, *,
preserve_module_call_signature: tuple[str, ...] = (), preserve_module_call_signature: tuple[str, ...] = (),
disable_constraint_solver: bool = False, disable_constraint_solver: bool = False,
allow_complex_guards_as_runtime_asserts: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False,
restore_fqn: bool = True, restore_fqn: bool = True,
_log_export_usage: bool = True, _log_export_usage: bool = True,
same_signature: bool = True, same_signature: bool = True,
@ -810,10 +810,7 @@ def _export_to_torch_ir(
assume_static_by_default=True, assume_static_by_default=True,
tracing_mode="symbolic", tracing_mode="symbolic",
disable_constraint_solver=disable_constraint_solver, disable_constraint_solver=disable_constraint_solver,
# currently the following 2 flags are tied together for export purposes, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
# but untangle for sake of dynamo export api
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
_log_export_usage=_log_export_usage, _log_export_usage=_log_export_usage,
same_signature=same_signature, same_signature=same_signature,
)( )(
@ -1402,7 +1399,7 @@ def _strict_export(
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
preserve_module_call_signature: tuple[str, ...], preserve_module_call_signature: tuple[str, ...],
orig_in_spec: TreeSpec, orig_in_spec: TreeSpec,
allow_complex_guards_as_runtime_asserts: bool, prefer_deferred_runtime_asserts_over_guards: bool,
_to_aten_func: Callable, _to_aten_func: Callable,
) -> ExportArtifact: ) -> ExportArtifact:
""" """
@ -1416,7 +1413,7 @@ def _strict_export(
dynamic_shapes, dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
restore_fqn=False, # don't need to restore because we will do it later restore_fqn=False, # don't need to restore because we will do it later
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_log_export_usage=False, _log_export_usage=False,
) )
@ -1859,7 +1856,7 @@ def _non_strict_export(
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
preserve_module_call_signature: tuple[str, ...], preserve_module_call_signature: tuple[str, ...],
orig_in_spec: TreeSpec, orig_in_spec: TreeSpec,
allow_complex_guards_as_runtime_asserts: bool, prefer_deferred_runtime_asserts_over_guards: bool,
_to_aten_func: Callable, _to_aten_func: Callable,
) -> ExportArtifact: ) -> ExportArtifact:
""" """
@ -1956,7 +1953,7 @@ def _non_strict_export(
args, args,
kwargs, kwargs,
dynamic_shapes, dynamic_shapes,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, # for shape env initialization
) )
fake_params_buffers = _fakify_params_buffers(fake_mode, mod) fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
@ -2037,7 +2034,7 @@ def _export_for_training(
*, *,
strict: bool = True, strict: bool = True,
preserve_module_call_signature: tuple[str, ...] = (), preserve_module_call_signature: tuple[str, ...] = (),
allow_complex_guards_as_runtime_asserts: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram: ) -> ExportedProgram:
global _EXPORT_MODULE_HIERARCHY global _EXPORT_MODULE_HIERARCHY
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@ -2062,7 +2059,7 @@ def _export_for_training(
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
orig_in_spec=orig_in_spec, orig_in_spec=orig_in_spec,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_to_aten_func=_export_to_aten_ir_make_fx, _to_aten_func=_export_to_aten_ir_make_fx,
) )
@ -2124,7 +2121,7 @@ def _export(
strict: bool = True, strict: bool = True,
preserve_module_call_signature: tuple[str, ...] = (), preserve_module_call_signature: tuple[str, ...] = (),
pre_dispatch: bool = False, pre_dispatch: bool = False,
allow_complex_guards_as_runtime_asserts: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram: ) -> ExportedProgram:
""" """
Traces either an nn.Module's forward function or just a callable with PyTorch Traces either an nn.Module's forward function or just a callable with PyTorch
@ -2155,7 +2152,7 @@ def _export(
preserve_module_call_signature: A list of submodule paths for which the original preserve_module_call_signature: A list of submodule paths for which the original
calling conventions are preserved as metadata. calling conventions are preserved as metadata.
allow_complex_guards_as_runtime_asserts: prefer_deferred_runtime_asserts_over_guards:
With the current dynamic shapes language for dims and derived dims, we can run into constraints With the current dynamic shapes language for dims and derived dims, we can run into constraints
that are not expressible with the language. For example, flattening a matrix and adding to a vector, that are not expressible with the language. For example, flattening a matrix and adding to a vector,
both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
@ -2199,7 +2196,7 @@ def _export(
dynamic_shapes, dynamic_shapes,
strict=strict, strict=strict,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
) )
dtrace_structured("exported_program", payload_fn=lambda: str(ep)) dtrace_structured("exported_program", payload_fn=lambda: str(ep))
return ep return ep
@ -2224,7 +2221,7 @@ def _export(
dynamic_shapes=dynamic_shapes, dynamic_shapes=dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature, preserve_module_call_signature=preserve_module_call_signature,
orig_in_spec=original_in_spec, orig_in_spec=original_in_spec,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
_to_aten_func=functools.partial( _to_aten_func=functools.partial(
_export_to_aten_ir, _export_to_aten_ir,
pre_dispatch=pre_dispatch, pre_dispatch=pre_dispatch,

View File

@ -3536,7 +3536,6 @@ class ShapeEnvSettings:
specialize_zero_one: bool specialize_zero_one: bool
duck_shape: bool duck_shape: bool
prefer_deferred_runtime_asserts_over_guards: bool prefer_deferred_runtime_asserts_over_guards: bool
allow_complex_guards_as_runtime_asserts: bool
trace_asserts: bool trace_asserts: bool
@ -3674,10 +3673,6 @@ class ShapeEnv:
# in guards is helpful, since these guards in some sense are overly # in guards is helpful, since these guards in some sense are overly
# pedantic. See also https://github.com/pytorch/pytorch/issues/121749 # pedantic. See also https://github.com/pytorch/pytorch/issues/121749
prefer_deferred_runtime_asserts_over_guards: bool = False, prefer_deferred_runtime_asserts_over_guards: bool = False,
# When True, does not emit or raise constraint violation errors on
# implicit guards generated by ops, and defers to runtime assertions
# in the graph instead. For export.
allow_complex_guards_as_runtime_asserts: bool = False,
# XXX Add any new settings that could affect FakeTensor evaluation # XXX Add any new settings that could affect FakeTensor evaluation
# to: torch._subclasses.fake_tensor._ShapeEnvSettings # to: torch._subclasses.fake_tensor._ShapeEnvSettings
trace_asserts: bool = False, trace_asserts: bool = False,
@ -3694,7 +3689,6 @@ class ShapeEnv:
specialize_zero_one=specialize_zero_one, specialize_zero_one=specialize_zero_one,
duck_shape=duck_shape, duck_shape=duck_shape,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
trace_asserts=trace_asserts, trace_asserts=trace_asserts,
) )
@ -3906,10 +3900,6 @@ class ShapeEnv:
def prefer_deferred_runtime_asserts_over_guards(self) -> bool: def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
return self.settings.prefer_deferred_runtime_asserts_over_guards return self.settings.prefer_deferred_runtime_asserts_over_guards
@property
def allow_complex_guards_as_runtime_asserts(self) -> bool:
return self.settings.allow_complex_guards_as_runtime_asserts
@contextmanager @contextmanager
def patch_source_specialization( def patch_source_specialization(
self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
@ -6659,7 +6649,7 @@ class ShapeEnv:
assert isinstance(a, sympy.Symbol) assert isinstance(a, sympy.Symbol)
if ( if (
self.allow_complex_guards_as_runtime_asserts self.prefer_deferred_runtime_asserts_over_guards
and not _is_supported_equivalence(tgt) and not _is_supported_equivalence(tgt)
): ):
return # continuing leads to placeholder shapes having complex expressions that we can't resolve return # continuing leads to placeholder shapes having complex expressions that we can't resolve
@ -7641,7 +7631,15 @@ class ShapeEnv:
# is no longer necessary) # is no longer necessary)
self._maybe_guard_rel(g) self._maybe_guard_rel(g)
if not self.allow_complex_guards_as_runtime_asserts: if (
torch.compiler.is_exporting()
and self.prefer_deferred_runtime_asserts_over_guards
):
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
# and so the result here will be statically known
self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
else:
# at this point, we've evaluated the concrete expr value, and have # at this point, we've evaluated the concrete expr value, and have
# flipped/negated the guard if necessary. Now we know what to guard # flipped/negated the guard if necessary. Now we know what to guard
# or defer to runtime assert on. # or defer to runtime assert on.
@ -7650,11 +7648,6 @@ class ShapeEnv:
) )
self.guards.append(guard) self.guards.append(guard)
self.axioms.update(dict(self.get_implications(self.simplify(g)))) self.axioms.update(dict(self.get_implications(self.simplify(g))))
else:
# it's fine to defer simple guards here without checking,
# the _maybe_guard_rel() call above will set replacements if possible,
# and so the result here will be statically known
self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
else: else:
self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)