Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
kill allow_complex_guards_as_runtime_asserts (#161794)
Summary: [reland] Since `allow_complex_guards_as_runtime_asserts` is now synced with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was an export-only concept).

Test Plan: updated tests

Rollback Plan:

Differential Revision: D81334984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161794
Approved by: https://github.com/zhxchen17
Committed by: PyTorch MergeBot
Parent: aad96a2022
Commit: 3c45af079a
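For callers, the change is mechanical: tests that previously reached into the private torch.export._trace._export entry point with allow_complex_guards_as_runtime_asserts=True now call the public torch.export.export with prefer_deferred_runtime_asserts_over_guards=True, as the hunks below show. A minimal migration sketch, assuming a toy module and dim names that are illustrative rather than taken from the diff:

# Migration sketch; Reshaper, dx, dy are illustrative names, not from this PR.
import torch
from torch.export import Dim, export

class Reshaper(torch.nn.Module):
    def forward(self, x):
        # A fully dynamic reshape emits a modulo guard (e.g. Eq(Mod(dx*dy, 4), 0))
        # that the dims/derived-dims language cannot express.
        return x.reshape(-1, 4)

dx, dy = Dim("dx"), Dim("dy")

# Before this PR (removed): private entry point + private flag
# ep = torch.export._trace._export(
#     Reshaper(), (torch.randn(8, 7),),
#     dynamic_shapes={"x": (dx, dy)},
#     allow_complex_guards_as_runtime_asserts=True,
# )

# After this PR: the surviving flag defers the complex guard to a runtime
# assertion in the exported graph instead of erroring out or specializing.
ep = export(
    Reshaper(),
    (torch.randn(8, 7),),
    dynamic_shapes={"x": (dx, dy)},
    prefer_deferred_runtime_asserts_over_guards=True,
)
ep.module()(torch.randn(12, 5))  # 60 elements, still divisible by 4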
@@ -263,7 +263,7 @@ class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
             dynamic_shapes=None,
             preserve_module_call_signature=(),
             restore_fqn=False,
-            allow_complex_guards_as_runtime_asserts=False,
+            prefer_deferred_runtime_asserts_over_guards=False,
             _log_export_usage=False,
         )
         # NOTE: this is necessary for rng to be added to the exported graph

@@ -10903,8 +10903,8 @@ def ___make_guard_fn():
 ShapeEnv not equal: field values don't match:

 ==> settings: values don't match.
-> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
-> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
+> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
+> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
 """,
         )
         self._replay_and_check(main)
@@ -5693,11 +5693,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         dim0_x = torch.export.Dim("dim0_x", min=3)
         dim1_x = torch.export.Dim("dim1_x", max=8000)
         dynamic_shapes = {"x": (dim0_x, dim1_x)}
-        em = torch.export._trace._export(
+        em = torch.export.export(
             m,
             (a,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         em.module()(torch.randn(4, 3))
         with self.assertRaisesRegex(

@@ -13581,7 +13581,7 @@ def forward(self, x, y):

     def test_disable_forced_specializations_ok(self):
         # check that we don't force specialization, and defer to runtime asserts
-        # with allow_complex_guards_as_runtime_asserts=True to successfully export
+        # with prefer_deferred_runtime_asserts_over_guards=True to successfully export
         # case 1: modulo guards
         from torch.export import dims

@@ -13591,11 +13591,11 @@ def forward(self, x, y):

         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Mod4Reshape(),
             inputs,
             dynamic_shapes={"x": (dx, dy)},
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(8, 7))
         self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)

@@ -13625,11 +13625,11 @@ def forward(self, x, y):

         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     FreeReshape(),
                     inputs,
                     dynamic_shapes=dynamic_shapes,
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13666,11 +13666,11 @@ def forward(self, x, y):
             "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
             "y": (Dim("dy", min=8),),
         }
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Reshape3d(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
         self.assertEqual(out1.shape, torch.ones(126).shape)

@@ -13792,11 +13792,11 @@ def forward(self, x, y):
         model = Model()
         x = torch.rand(1024, 20, 16)
         dynamic_shapes = {"x": {0: Dim("batch")}}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             model,
             (x,),
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         with self.assertRaisesRegex(
             RuntimeError,

@@ -13869,11 +13869,11 @@ def forward(self, x, y):

         inputs = (torch.randn(6), torch.randn(12))
         dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # check forward pass
         out0, out1 = ep.module()(torch.randn(9), torch.randn(27))

@@ -13908,7 +13908,7 @@ def forward(self, x, y):
             Foo(),
             inputs,
             dynamic_shapes=dynamic_shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         ).run_decompositions()

         self.assertEqual(
@@ -14320,11 +14320,11 @@ graph():

         inputs = (torch.randn(5), torch.randn(3))
         shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-        ep = torch.export._trace._export(
+        ep = torch.export.export(
             Foo(),
             inputs,
             dynamic_shapes=shapes,
-            allow_complex_guards_as_runtime_asserts=True,
+            prefer_deferred_runtime_asserts_over_guards=True,
         )
         # count 2 pow nodes, 2 sym_size.int nodes
         self.assertEqual(

@@ -15123,11 +15123,11 @@ def forward(self, x):

         for private_api in (True, False):
             if private_api:
-                ep = torch.export._trace._export(
+                ep = torch.export.export(
                     ModConstraint(),
                     (torch.randn(3, 4),),
                     dynamic_shapes={"x": (dynamic, dynamic)},
-                    allow_complex_guards_as_runtime_asserts=True,
+                    prefer_deferred_runtime_asserts_over_guards=True,
                 )
             else:
                 ep = export(

@@ -15141,7 +15141,7 @@ def forward(self, x):
                 for node in ep.graph.nodes
             ].count(True)
             if private_api:
-                self.assertEqual(num_asserts, 7)
+                self.assertEqual(num_asserts, 6)
             with self.assertRaisesRegex(
                 RuntimeError,
                 r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
@@ -258,12 +258,6 @@ capture_dynamic_output_shape_ops = (
 # hybrid backed unbacked symints
 prefer_deferred_runtime_asserts_over_guards = False

-# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
-# range constraints + dims + derived dims language, we raise constraint violation
-# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
-# and allows complex guards as runtime assertions in the graph.
-allow_complex_guards_as_runtime_asserts = False
-
 # By default, dynamo will treat all ints as backed SymInts, which means (1) it
 # will wait to see the int change over multiple runs before generalizing and
 # (2) it will still always 0/1 specialize an int. When true, this knob
@@ -1734,7 +1734,6 @@ def export(
     same_signature: bool = True,
     disable_constraint_solver: bool = False,
     prefer_deferred_runtime_asserts_over_guards: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
     _log_export_usage: bool = True,
     constraints: Optional[list[Constraint]] = None,
     **extra_kwargs: Any,

@@ -1961,7 +1960,6 @@ def export(
             capture_dynamic_output_shape_ops=True,
             capture_scalar_outputs=True,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
         ),
         _compiling_state_context(),
     ):

@@ -468,7 +468,6 @@ class OutputGraph(OutputGraphGuardsState):
             allow_scalar_outputs=config.capture_scalar_outputs,
             allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
             prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
-            allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
             co_fields=self.co_fields,
         )

@@ -330,7 +330,7 @@ def make_fake_inputs(
     args,
     kwargs,
     dynamic_shapes,
-    allow_complex_guards_as_runtime_asserts=False,
+    prefer_deferred_runtime_asserts_over_guards=False,
 ):
     """
     Given an nn module, example inputs, and constraints, return a new fake mode,

@@ -382,8 +382,7 @@ def make_fake_inputs(
         shape_env=ShapeEnv(
             tracked_fakes=[],
             co_fields=co_fields,
-            prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
             trace_asserts=True,
         ),
         allow_non_fake_inputs=True,
@@ -158,7 +158,7 @@ def export_for_training(
         dynamic_shapes,
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
-        allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
     )


@@ -282,7 +282,7 @@ def export(
             strict=strict,
             preserve_module_call_signature=preserve_module_call_signature,
             pre_dispatch=True,
-            allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         )
     except Exception as e:
         draft_export_msg = (
@@ -756,7 +756,7 @@ def _export_to_torch_ir(
     *,
     preserve_module_call_signature: tuple[str, ...] = (),
     disable_constraint_solver: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
     restore_fqn: bool = True,
     _log_export_usage: bool = True,
     same_signature: bool = True,

@@ -816,10 +816,7 @@ def _export_to_torch_ir(
             assume_static_by_default=True,
             tracing_mode="symbolic",
             disable_constraint_solver=disable_constraint_solver,
-            # currently the following 2 flags are tied together for export purposes,
-            # but untangle for sake of dynamo export api
-            prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+            prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
             _log_export_usage=_log_export_usage,
             same_signature=same_signature,
         )(

@@ -1408,7 +1405,7 @@ def _strict_export(
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
     preserve_module_call_signature: tuple[str, ...],
     orig_in_spec: TreeSpec,
-    allow_complex_guards_as_runtime_asserts: bool,
+    prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
 ) -> ExportArtifact:
     """

@@ -1422,7 +1419,7 @@ def _strict_export(
         dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         restore_fqn=False,  # don't need to restore because we will do it later
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _log_export_usage=False,
     )

@@ -1865,7 +1862,7 @@ def _non_strict_export(
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
     preserve_module_call_signature: tuple[str, ...],
     orig_in_spec: TreeSpec,
-    allow_complex_guards_as_runtime_asserts: bool,
+    prefer_deferred_runtime_asserts_over_guards: bool,
     _to_aten_func: Callable,
 ) -> ExportArtifact:
     """

@@ -1962,7 +1959,7 @@ def _non_strict_export(
         args,
         kwargs,
         dynamic_shapes,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,  # for shape env initialization
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,  # for shape env initialization
     )

     fake_params_buffers = _fakify_params_buffers(fake_mode, mod)

@@ -2043,7 +2040,7 @@ def _export_for_training(
     *,
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
     global _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)

@@ -2078,7 +2075,7 @@ def _export_for_training(
         dynamic_shapes=dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         orig_in_spec=orig_in_spec,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _to_aten_func=_export_to_aten_ir_make_fx,
     )

@@ -2187,7 +2184,7 @@ def _export(
     strict: bool = True,
     preserve_module_call_signature: tuple[str, ...] = (),
     pre_dispatch: bool = False,
-    allow_complex_guards_as_runtime_asserts: bool = False,
+    prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
     """
     Traces either an nn.Module's forward function or just a callable with PyTorch

@@ -2218,7 +2215,7 @@ def _export(
         preserve_module_call_signature: A list of submodule paths for which the original
             calling conventions are preserved as metadata.

-        allow_complex_guards_as_runtime_asserts:
+        prefer_deferred_runtime_asserts_over_guards:
         With the current dynamic shapes language for dims and derived dims, we can run into constraints
         that are not expressible with the language. For example, flattening a matrix and adding to a vector,
         both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.

@@ -2262,7 +2259,7 @@ def _export(
         dynamic_shapes,
         strict=strict,
         preserve_module_call_signature=preserve_module_call_signature,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
     )
     dtrace_structured("exported_program", payload_fn=lambda: str(ep))
     return ep

@@ -2287,7 +2284,7 @@ def _export(
         dynamic_shapes=dynamic_shapes,
         preserve_module_call_signature=preserve_module_call_signature,
         orig_in_spec=original_in_spec,
-        allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+        prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
         _to_aten_func=functools.partial(
             _export_to_aten_ir,
             pre_dispatch=pre_dispatch,
@@ -3536,7 +3536,6 @@ class ShapeEnvSettings:
     specialize_zero_one: bool
     duck_shape: bool
     prefer_deferred_runtime_asserts_over_guards: bool
-    allow_complex_guards_as_runtime_asserts: bool
     trace_asserts: bool


@@ -3674,10 +3673,6 @@ class ShapeEnv:
         # in guards is helpful, since these guards in some sense are overly
         # pedantic. See also https://github.com/pytorch/pytorch/issues/121749
         prefer_deferred_runtime_asserts_over_guards: bool = False,
-        # When True, does not emit or raise constraint violation errors on
-        # implicit guards generated by ops, and defers to runtime assertions
-        # in the graph instead. For export.
-        allow_complex_guards_as_runtime_asserts: bool = False,
         # XXX Add any new settings that could affect FakeTensor evaluation
         # to: torch._subclasses.fake_tensor._ShapeEnvSettings
         trace_asserts: bool = False,

@@ -3694,7 +3689,6 @@ class ShapeEnv:
             specialize_zero_one=specialize_zero_one,
             duck_shape=duck_shape,
             prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-            allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
             trace_asserts=trace_asserts,
         )

@@ -3906,10 +3900,6 @@ class ShapeEnv:
     def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
         return self.settings.prefer_deferred_runtime_asserts_over_guards

-    @property
-    def allow_complex_guards_as_runtime_asserts(self) -> bool:
-        return self.settings.allow_complex_guards_as_runtime_asserts
-
     @contextmanager
     def patch_source_specialization(
         self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
@@ -6658,7 +6648,7 @@ class ShapeEnv:
         assert isinstance(a, sympy.Symbol)

         if (
-            self.allow_complex_guards_as_runtime_asserts
+            self.prefer_deferred_runtime_asserts_over_guards
             and not _is_supported_equivalence(tgt)
         ):
             return  # continuing leads to placeholder shapes having complex expressions that we can't resolve

@@ -7633,7 +7623,15 @@ class ShapeEnv:
             # is no longer necessary)
             self._maybe_guard_rel(g)

-            if not self.allow_complex_guards_as_runtime_asserts:
+            if (
+                torch.compiler.is_exporting()
+                and self.prefer_deferred_runtime_asserts_over_guards
+            ):
+                # it's fine to defer simple guards here without checking,
+                # the _maybe_guard_rel() call above will set replacements if possible,
+                # and so the result here will be statically known
+                self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+            else:
                 # at this point, we've evaluated the concrete expr value, and have
                 # flipped/negated the guard if necessary. Now we know what to guard
                 # or defer to runtime assert on.

@@ -7642,11 +7640,6 @@ class ShapeEnv:
                 )
                 self.guards.append(guard)
                 self.axioms.update(dict(self.get_implications(self.simplify(g))))
-            else:
-                # it's fine to defer simple guards here without checking,
-                # the _maybe_guard_rel() call above will set replacements if possible,
-                # and so the result here will be statically known
-                self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
         else:
             self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)

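For reference, the updated export tests exercise this path by exporting with prefer_deferred_runtime_asserts_over_guards=True, counting assertion nodes in the resulting graph, and checking that violating inputs fail with a "Runtime assertion failed for expression ..." error when the exported module is called. A rough sketch of that verification pattern, reusing the ep from the sketch near the top of this page (the assertion op target below is an assumption for illustration, not taken from the diff):

# Count deferred runtime asserts in the exported graph; the exact op target
# (aten._assert_scalar) is an illustrative assumption.
num_asserts = [
    node.target == torch.ops.aten._assert_scalar.default
    for node in ep.graph.nodes
].count(True)
print("deferred runtime asserts in the graph:", num_asserts)

# An input whose flattened size is not divisible by 4 should now trip the
# deferred assert at runtime rather than being rejected at export time.
try:
    ep.module()(torch.randn(7, 5))  # 35 elements, not divisible by 4
except RuntimeError as e:
    print(e)  # e.g. "Runtime assertion failed for expression Eq(Mod(...), 0) ..."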