mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
More readable Z3 expressions printer. (#106643)
This PR makes Z3 expressions easier to read and understand by creating a custom printer for them. Z3 expressions can be printed in 2 forms: 1. Using the builtin `str(e)` function 2. Using the `e.sexpr()` method Problem is that (1) is a bit hard to read because its line breaks are not so intuitive. (2) is a bit nicer, but the `to_int` and `to_real` functions clutter things up. The custom printer is an improved `sexpr()` function: - Leaves everything in one line - Gets rid of `to_int` and `to_real` functions - Reconstruct the floor division operations - Merge commutative operation chains Pull Request resolved: https://github.com/pytorch/pytorch/pull/106643 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
26846546e8
commit
33e70e34a3
@ -44,11 +44,7 @@ from torch.ao.quantization.fake_quantize import FakeQuantize
|
||||
from torch.ao.quantization.qconfig import QConfig
|
||||
from torch.ao.quantization.quantize_fx import prepare_qat_fx
|
||||
from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
FloorDiv,
|
||||
Mod,
|
||||
)
|
||||
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA,
|
||||
@ -56,7 +52,7 @@ from torch.testing._internal.common_cuda import (
|
||||
TEST_CUDA,
|
||||
TEST_MULTIGPU,
|
||||
)
|
||||
from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE, TEST_Z3
|
||||
from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
|
||||
@ -6226,133 +6222,6 @@ def ___make_guard_fn():
|
||||
self.assertEqual(counter.frame_count, 1)
|
||||
self.assertEqual(counter.op_count, 9)
|
||||
|
||||
def _prepare_for_translation_validation(self):
|
||||
from torch.fx.experimental.validator import TranslationValidator
|
||||
|
||||
validator = TranslationValidator()
|
||||
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_sympy_to_z3_translation(self):
|
||||
import z3
|
||||
from torch.fx.experimental.validator import SympyToZ3
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_sat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
r = validator.validate()
|
||||
self.assertEqual(r.success, True, msg=f"failed with model: {r.model}")
|
||||
self.assertIsNone(r.model)
|
||||
self.assertIsNone(r.failed_source_expr)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_unsat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
r = validator.validate()
|
||||
self.assertEqual(r.success, False, msg=f"failed with model: {r.model}")
|
||||
self.assertIsNotNone(r.model)
|
||||
self.assertIsNotNone(r.failed_source_expr)
|
||||
|
||||
def test_simple_set_usage(self):
|
||||
def foo(x, y):
|
||||
setty = {x, y}
|
||||
|
||||
@ -5,6 +5,7 @@ import numbers
|
||||
import operator
|
||||
import pickle
|
||||
import sys
|
||||
import sympy
|
||||
import tempfile
|
||||
import unittest
|
||||
from types import BuiltinFunctionType
|
||||
@ -48,7 +49,7 @@ from torch.testing._internal.common_device_type import (
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_nn import module_tests, new_module_tests
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
try:
|
||||
@ -1700,6 +1701,174 @@ class TestModule(torch.nn.Module):
|
||||
self.assertIs(kwargs["input"], inp1)
|
||||
self.assertIs(kwargs["the_template"], inp2)
|
||||
|
||||
|
||||
class TestTranslationValidator(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
from torch.fx.experimental.validator import TranslationValidator
|
||||
|
||||
validator = TranslationValidator()
|
||||
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_sympy_to_z3_translation(self):
|
||||
import z3
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
from torch.fx.experimental.validator import SympyToZ3
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_sat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_unsat(self):
|
||||
from torch.fx.experimental.validator import ValidationException
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_z3str(self):
|
||||
import z3
|
||||
from torch.fx.experimental.validator import z3str
|
||||
|
||||
a = z3.Int("a")
|
||||
b = z3.Int("b")
|
||||
special = z3.Real("this.size()[2]")
|
||||
|
||||
test_cases = [
|
||||
(z3.IntVal(42), "42"),
|
||||
# Variable.
|
||||
(a, "a"),
|
||||
# Name with special characters.
|
||||
(special, "this.size()[2]"),
|
||||
# Renamed function fpplications.
|
||||
(a != b, "(!= a b)"),
|
||||
(a ** b, "(pow a b)"),
|
||||
# Chain of associative operations.
|
||||
*[
|
||||
(op(op(a, 5), b), f"({opstr} 5 a b)")
|
||||
for op, opstr in [
|
||||
(operator.add, "+"),
|
||||
(operator.mul, "*")
|
||||
]
|
||||
],
|
||||
# Revert 'Not' conversions.
|
||||
(a != b, "(!= a b)"),
|
||||
(a < b, "(> b a)"),
|
||||
(a > b, "(> a b)"),
|
||||
# Ignore 'ToInt' and 'ToReal' functions.
|
||||
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
|
||||
(z3.ToReal(a + b), "(+ a b)"),
|
||||
# Convert to floor division: 'idiv'.
|
||||
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
|
||||
]
|
||||
|
||||
for expr, expected in test_cases:
|
||||
self.assertEqual(z3str(expr), expected)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestNormalizeOperators, globals())
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -2012,39 +2012,8 @@ class ShapeEnv:
|
||||
self.validator.add_assertion(expr)
|
||||
|
||||
def _check_translation_validate(self) -> None:
|
||||
if not _translation_validation_enabled():
|
||||
return
|
||||
|
||||
result = self.validator.validate()
|
||||
|
||||
if result.success:
|
||||
return
|
||||
|
||||
if result.model is None:
|
||||
reason = "no answer"
|
||||
source_exprs = self.validator._source_exprs
|
||||
failed = ""
|
||||
else:
|
||||
assert result.failed_source_expr is not None
|
||||
reason = "model: %s" % {sym: result.model[sym] for sym in result.model}
|
||||
source_exprs = result.failed_source_expr
|
||||
failed = "Failed "
|
||||
|
||||
def exprs_to_str(exprs):
|
||||
return "\n".join(f"==> {e}" for e in exprs)
|
||||
|
||||
assertions = self.validator._assertions
|
||||
target_exprs = self.validator._target_exprs
|
||||
|
||||
raise RuntimeError(f"""translation validation failed with {reason}.
|
||||
Assertions:
|
||||
{exprs_to_str(assertions)}
|
||||
|
||||
Target Guards:
|
||||
{exprs_to_str(target_exprs)}
|
||||
|
||||
{failed}Source Guards:
|
||||
{exprs_to_str(source_exprs)}""")
|
||||
if _translation_validation_enabled():
|
||||
self.validator.validate()
|
||||
|
||||
def create_fx_call_function(
|
||||
self,
|
||||
|
||||
@ -5,9 +5,10 @@ import operator
|
||||
import sympy
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.exc import TorchDynamoException
|
||||
import torch.fx
|
||||
import torch.fx.traceback as fx_traceback
|
||||
|
||||
@ -42,6 +43,101 @@ try:
|
||||
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
|
||||
# (see [Note: TranslationValidator])
|
||||
|
||||
# Better Z3 to string implementation (for a small fraction of Z3).
|
||||
#
|
||||
# Here are the things we clean before showing the Z3 expression:
|
||||
# - Rename a few ops (e.g. "Distinct" ==> "!=")
|
||||
#
|
||||
# - Ignore ToInt and ToReal operations:
|
||||
# usually they don't really matter
|
||||
#
|
||||
# - Transform (ToInt (/ ...)) into (idiv ...):
|
||||
# this is the pattern for floor division
|
||||
#
|
||||
# - Collect a chain of the same operations into one
|
||||
def z3str(e: z3.ExprRef) -> str:
|
||||
assert z3.is_expr(e), f"unsupported expression type: {e}"
|
||||
|
||||
def get_args_str(e: z3.ExprRef) -> List[str]:
|
||||
return [z3str(e.arg(i)) for i in range(e.num_args())]
|
||||
|
||||
# First, we simplify the given expression.
|
||||
# This is done using rewriting rules, so shouldn't take long.
|
||||
e = z3.simplify(e)
|
||||
|
||||
|
||||
# Only support function applications.
|
||||
# Even Z3 "variables" are, in fact, function applications.
|
||||
if not z3.is_app(e):
|
||||
raise ValueError(f"can't print Z3 expression: {e}")
|
||||
|
||||
if z3.is_int_value(e) or z3.is_rational_value(e):
|
||||
return e.as_string() # type: ignore[attr-defined]
|
||||
|
||||
decl = e.decl()
|
||||
kind = decl.kind()
|
||||
op = str(decl)
|
||||
args = get_args_str(e)
|
||||
|
||||
if kind == z3.Z3_OP_POWER:
|
||||
op = "pow"
|
||||
|
||||
elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
|
||||
# Collect the arguments of chains of ADD and MUL.
|
||||
# This is safe, since they are associative.
|
||||
|
||||
def collect_str_args(e):
|
||||
if not (z3.is_app(e) and e.decl().kind() == kind):
|
||||
return [z3str(e)]
|
||||
else:
|
||||
return [
|
||||
x
|
||||
for i in range(e.num_args())
|
||||
for x in collect_str_args(e.arg(i))
|
||||
]
|
||||
|
||||
args = collect_str_args(e)
|
||||
|
||||
elif kind == z3.Z3_OP_NOT:
|
||||
# Revert some conversions that z3.simplify applies:
|
||||
# - a != b ==> (Not (== a b)) ==> (!= a b)
|
||||
# - a < b ==> (Not (<= b a)) ==> (> b a)
|
||||
# - a > b ==> (Not (<= a b)) ==> (> a b)
|
||||
|
||||
assert e.num_args() == 1
|
||||
arg = e.arg(0)
|
||||
|
||||
assert z3.is_app(arg)
|
||||
argkind = arg.decl().kind()
|
||||
|
||||
logic_inverse = {
|
||||
z3.Z3_OP_EQ: "!=",
|
||||
z3.Z3_OP_LE: ">",
|
||||
z3.Z3_OP_GE: "<",
|
||||
}
|
||||
|
||||
if argkind in logic_inverse:
|
||||
op = logic_inverse[argkind]
|
||||
args = get_args_str(arg)
|
||||
|
||||
elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
|
||||
assert e.num_args() == 1
|
||||
argstr = z3str(e.arg(0))
|
||||
|
||||
# Check if it's the floor division pattern.
|
||||
if argstr.startswith("(/"):
|
||||
return "(idiv" + argstr[2:]
|
||||
|
||||
# Otherwise, just ignore it.
|
||||
return argstr
|
||||
|
||||
elif kind == z3.Z3_OP_UNINTERPRETED:
|
||||
assert e.num_args() == 0
|
||||
return str(decl)
|
||||
|
||||
string = op + " " + " ".join(args)
|
||||
return f"({string.rstrip()})"
|
||||
|
||||
# Implementation of Python semantics as Z3 expressions.
|
||||
#
|
||||
# Z3 Real-Int theory has operators with semantics that differ that of
|
||||
@ -363,24 +459,13 @@ try:
|
||||
assert isinstance(ref, z3.BoolRef)
|
||||
self._assertions.add(ref)
|
||||
|
||||
# The result of a validation run.
|
||||
@dataclass
|
||||
class Result:
|
||||
success: bool
|
||||
|
||||
# Mapping of the name of each free variable to the value assigned to it.
|
||||
model: Optional[z3.ModelRef] = None
|
||||
|
||||
# List of the source expressions that failed due to the assignment.
|
||||
failed_source_expr: Optional[List[z3.BoolRef]] = None
|
||||
|
||||
def validate(self) -> "TranslationValidator.Result":
|
||||
def validate(self) -> None:
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
|
||||
if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
|
||||
# If there are no source/target expressions, there's nothing we really
|
||||
# wish to prove. So, we just return.
|
||||
return self.Result(success=True)
|
||||
return None
|
||||
|
||||
# Here, we use "QF_NRA" logic for the solver:
|
||||
# "Quantifier-free Non-linear Real Arithmetic".
|
||||
@ -411,10 +496,11 @@ try:
|
||||
# Target expressions are unsound.
|
||||
# Log the found model and the source expressions that failed.
|
||||
model = solver.model()
|
||||
return self.Result(
|
||||
success=False,
|
||||
model=model,
|
||||
failed_source_expr=[inp for inp in self._source_exprs if not model.evaluate(inp)],
|
||||
raise ValidationException(
|
||||
model, self._assertions, self._target_exprs,
|
||||
failed_source_exprs=[
|
||||
inp for inp in self._source_exprs if not model.evaluate(inp)
|
||||
]
|
||||
)
|
||||
else:
|
||||
if r == z3.unknown:
|
||||
@ -426,7 +512,32 @@ try:
|
||||
# Target expressions are sound.
|
||||
assert r == z3.unsat
|
||||
log.debug("translation validation: success")
|
||||
return self.Result(success=True)
|
||||
|
||||
|
||||
class ValidationException(TorchDynamoException):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
assertions: Iterable[z3.ExprRef],
|
||||
target_exprs: Iterable[z3.ExprRef],
|
||||
failed_source_exprs: Iterable[z3.ExprRef]
|
||||
) -> None:
|
||||
model_str = self._joinlines(model, lambda sym: f"{sym}: {model[sym]}")
|
||||
assertions_str = self._joinlines(assertions, z3str)
|
||||
target_exprs_str = self._joinlines(target_exprs, z3str)
|
||||
failed_source_exprs_str = self._joinlines(failed_source_exprs, z3str)
|
||||
|
||||
super().__init__(
|
||||
"translation validation failed.\n\n"
|
||||
"Model:\n" + model_str + "\n\n"
|
||||
"Assertions:\n" + assertions_str + "\n\n"
|
||||
"Target Expressions:\n" + target_exprs_str + "\n\n"
|
||||
"Failed Source Expressions:\n" + failed_source_exprs_str + "\n\n"
|
||||
)
|
||||
|
||||
def _joinlines(self, xs: Iterable[Any], f: Callable[[Any], str] = lambda x: x) -> str:
|
||||
return "\n".join(f" ==> {f(x)}" for x in xs)
|
||||
|
||||
except ImportError:
|
||||
_HAS_Z3 = False
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user