Run translation validation on tracing error. (#106645)

This PR wraps `InstructionTranslator` run with a try-catch block so as to run the
translation validation (TV) if it ends up raising an error.

In this context, we run TV so as to catch simplification errors. These may turn
`ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.

For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since
it's run only in the end of the tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106645
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi
2023-08-12 18:07:55 -03:00
committed by PyTorch MergeBot
parent 937cd3742b
commit d8ad74857c
6 changed files with 365 additions and 271 deletions

View File

@ -9,7 +9,8 @@ import torch._dynamo.config
import torch._dynamo.test_case
from torch._dynamo.comptime import comptime
from torch._dynamo.exc import Unsupported
from torch.testing._internal.common_utils import munge_exc
from torch.testing._internal.common_device_type import skipIf
from torch.testing._internal.common_utils import munge_exc, TEST_Z3
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
@ -189,6 +190,53 @@ backend='relu_compile_error_TESTING_ONLY' raised:
ReluCompileError:""",
)
@skipIf(not TEST_Z3, "z3 not installed")
@torch._dynamo.config.patch(
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
assume_static_by_default=False,
translation_validation=True,
suppress_errors=False,
)
def test_trigger_on_error(self):
from torch.fx.experimental.validator import ValidationException
@torch.compile
def fn(x):
return x.reshape(-1, 4)
self.assertExpectedInlineMunged(
ValidationException,
lambda: fn(torch.randn(20)),
"""\
translation validation failed.
Model:
==> L['x'].storage_offset(): 0
==> s0: 4
==> L['x'].stride()[0]: 1
==> L['x'].size()[0]: 4
Assertions:
==> (== L['x'].size()[0] s0)
==> (> s0 1)
==> (Not (And (< L['x'].size()[0] 4) (>= L['x'].size()[0] 0)))
==> (== 0 L['x'].storage_offset())
==> (== 1 L['x'].stride()[0])
==> (True)
Target Expressions:
==> (>= 9223372036854775806 s0)
==> (== 4 L['x'].size()[0])
==> (== 0 L['x'].storage_offset())
==> (> s0 0)
==> (== 1 L['x'].stride()[0])
==> (<= 2 s0)
==> (== 4 s0)
Failed Source Expressions:
==> (!= 4 L['x'].size()[0])""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -1702,171 +1702,165 @@ class TestModule(torch.nn.Module):
self.assertIs(kwargs["the_template"], inp2)
class TestTranslationValidator(TestCase):
def _prepare_for_translation_validation(self):
from torch.fx.experimental.validator import TranslationValidator
if TEST_Z3:
import z3
validator = TranslationValidator()
import torch._dynamo.config
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
from torch.utils._sympy.functions import FloorDiv, Mod
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
class TestTranslationValidation(TestCase):
def _prepare_for_translation_validation(self):
validator = TranslationValidator()
return (s0, s1, s2), (z0, z1, z2), validator
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_sympy_to_z3_translation(self):
import z3
from torch.utils._sympy.functions import FloorDiv, Mod
from torch.fx.experimental.validator import SympyToZ3
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
return (s0, s1, s2), (z0, z1, z2), validator
def test_sympy_to_z3(self):
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
def test_sat(self):
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_unsat(self):
from torch.fx.experimental.validator import ValidationException
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_z3str(self):
import z3
from torch.fx.experimental.validator import z3str
def test_unsat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
def test_z3str(self):
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")
test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
instantiate_device_type_tests(TestNormalizeOperators, globals())

View File

@ -290,6 +290,11 @@ capture_func_transforms = True
# used for testing
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
# wraps (un)equalities with 'Not' class after recording the correct expression
# in the FX graph. This should incorrectly construct the divisible and replacement
# lists, and incorrectly issue guards.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
_autograd_backward_strict_mode_banned_ops = [
"stride",
"requires_grad",

View File

@ -400,6 +400,11 @@ def _compile(
frame: Optional[types.FrameType] = None,
frame_state=None,
) -> Optional[GuardedCode]:
from torch.fx.experimental.validator import (
translation_validation_enabled,
ValidationException,
)
output: Optional[OutputGraph] = None
# This is shared across restarts
mutated_closure_cell_contents: Set[str] = set()
@ -421,8 +426,21 @@ def _compile(
mutated_closure_cell_contents,
frame_state=frame_state,
)
with tracing(tracer.output.tracing_context):
tracer.run()
try:
with tracing(tracer.output.tracing_context):
tracer.run()
except (exc.RestartAnalysis, exc.SkipFrame):
raise
except Exception:
if translation_validation_enabled():
fakes = tracer.output.tracked_fakes
tracer.output.shape_env.produce_guards(
[a.fake for a in fakes],
[a.source for a in fakes],
)
raise
output = tracer.output
assert output is not None
assert output.output_instructions
@ -542,6 +560,7 @@ def _compile(
AssertionError,
ConstraintViolationError,
GuardOnDataDependentSymNode,
ValidationException,
) as e:
fail_reason = str(e)
exception_handler(e, code, frame, export=export)

View File

@ -2072,6 +2072,10 @@ class ShapeEnv:
return self.fx_node_cache[node_key]
def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
if _translation_validation_enabled() and node is not None:
self.graph.erase_node(node)
def _suppress_guards_tls(self):
return getattr(TLS, "suppress_guards", False)
@ -3168,119 +3172,144 @@ class ShapeEnv:
#
# If all of the above check, we create an FX node representing the
# actual expression to be guarded.
node = None
if (
_translation_validation_enabled()
and fx_node is not None
and not self._suppress_guards_tls()
):
if concrete_val is sympy.true:
self.create_fx_call_function(torch._assert, (fx_node,))
node = self.create_fx_call_function(torch._assert, (fx_node,))
elif concrete_val is sympy.false:
neg = self.create_fx_call_function(operator.not_, (fx_node,))
self.create_fx_call_function(torch._assert, (neg,))
node = self.create_fx_call_function(torch._assert, (neg,))
else:
eql = self.create_fx_call_function(operator.eq, (fx_node, concrete_val))
self.create_fx_call_function(torch._assert, (eql,))
node = self.create_fx_call_function(torch._assert, (eql,))
if len(orig_expr.free_symbols) == 0:
self.log.debug("eval %s [trivial]", orig_expr)
# NB: don't test float as there may be precision issues
if isinstance(hint, (int, bool)):
assert orig_expr == hint, f"{orig_expr} != {hint}"
return orig_expr
# After creating the FX node corresponding to orig_expr, we must make sure that
# no error will be raised until the end of this function.
#
# Reason: the translation validation may become invalid otherwise.
#
# If an error is raised before the end of this function, we remove the FX node
# inserted, and re-raise the error.
guard = None
tb = None
expr = orig_expr
try:
if len(orig_expr.free_symbols) == 0:
self.log.debug("eval %s [trivial]", orig_expr)
# NB: don't test float as there may be precision issues
if isinstance(hint, (int, bool)):
assert orig_expr == hint, f"{orig_expr} != {hint}"
return orig_expr
static_expr = self._maybe_evaluate_static(expr)
if static_expr is not None:
self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
# NB: don't test float as there may be precision issues
if isinstance(hint, (int, bool)):
assert static_expr == hint, f"{static_expr} != {hint}"
return static_expr
expr = orig_expr
if not (expr.free_symbols <= self.var_to_val.keys()):
# TODO: dedupe this with _maybe_evaluate_static
# Attempt to eliminate the unbacked SymInt
new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
if not (new_expr.free_symbols <= self.var_to_val.keys()):
raise self._make_data_dependent_error(expr.xreplace(self.var_to_val), expr)
expr = new_expr
static_expr = self._maybe_evaluate_static(expr)
if static_expr is not None:
self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
# NB: don't test float as there may be precision issues
if isinstance(hint, (int, bool)):
assert static_expr == hint, f"{static_expr} != {hint}"
return static_expr
if self.frozen:
self.counter["ignored_backward_guard"] += 1
signpost_event(
"dynamic",
"evaluate_expr_frozen",
{
**self.co_fields,
"ignored_guard": f"{expr} == {concrete_val}",
# no version = original state (this signpost is expected)
# version 2 = dynamic backwards is eagerly compiled
"version": 2,
},
)
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
if not (expr.free_symbols <= self.var_to_val.keys()):
# TODO: dedupe this with _maybe_evaluate_static
# Attempt to eliminate the unbacked SymInt
new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
if not (new_expr.free_symbols <= self.var_to_val.keys()):
raise self._make_data_dependent_error(expr.xreplace(self.var_to_val), expr)
expr = new_expr
if isinstance(expr, (sympy.Eq, sympy.Ne)):
self._maybe_guard_eq(expr, bool(concrete_val))
# TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that
# input against the symbol
elif isinstance(concrete_val, sympy.Integer):
# WARNING: we cannot actually do simplifications on guards
# on floating point values, because Sympy generally does not
# think expressions on integers can ever be equal to floating
# point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without
# very clear algebraic laws that hold for floating point, such
# simplifications are error prone anyway, so be sure not to
# maybe_guard_eq in those cases.
self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True)
if concrete_val is sympy.true:
g = expr
elif concrete_val is sympy.false:
g = sympy.Not(expr)
else:
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
guard = ShapeGuard(g, stack)
self.guards.append(guard)
self.refine_ranges(guard)
if self.log.isEnabledFor(logging.INFO):
for frame in reversed(tb):
if frame.filename not in uninteresting_files():
break
# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
maybe_user_loc = ""
user_tb = TracingContext.extract_stack()
if user_tb:
maybe_user_loc = " at " + format_frame(user_tb[-1])
is_debug = self.log.isEnabledFor(logging.DEBUG)
maybe_extra_debug = ""
if is_debug and user_tb:
maybe_extra_debug = (
'\nUser Stack (most recent call last):\n' +
' (snipped, see stack below for prefix)\n' +
''.join(traceback.format_list(user_tb))
)
self.log.info(
"eval %s [guard added]%s (%s)%s",
g,
maybe_user_loc,
format_frame(frame),
maybe_extra_debug,
stack_info=is_debug,
if self.frozen:
self.counter["ignored_backward_guard"] += 1
signpost_event(
"dynamic",
"evaluate_expr_frozen",
{
**self.co_fields,
"ignored_guard": f"{expr} == {concrete_val}",
# no version = original state (this signpost is expected)
# version 2 = dynamic backwards is eagerly compiled
"version": 2,
},
)
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
if torch._dynamo.config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY and isinstance(hint, bool):
if isinstance(expr, (sympy.Eq, sympy.Ne)):
expr = sympy.Not(expr)
if isinstance(expr, (sympy.Eq, sympy.Ne)):
self._maybe_guard_eq(expr, bool(concrete_val))
# TODO: If we successfully eliminate a symbol via equality, it
# is not actually necessary to save a guard for the equality,
# as we will implicitly generate a guard when we match that
# input against the symbol
elif isinstance(concrete_val, sympy.Integer):
# WARNING: we cannot actually do simplifications on guards
# on floating point values, because Sympy generally does not
# think expressions on integers can ever be equal to floating
# point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without
# very clear algebraic laws that hold for floating point, such
# simplifications are error prone anyway, so be sure not to
# maybe_guard_eq in those cases.
self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True)
if concrete_val is sympy.true:
g = expr
elif concrete_val is sympy.false:
g = sympy.Not(expr)
else:
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
if not self._suppress_guards_tls():
tb = traceback.extract_stack()[:-1]
stack = ''.join(traceback.format_list(tb))
guard = ShapeGuard(g, stack)
self.guards.append(guard)
except Exception:
self.remove_fx_node(node)
raise
else:
self.log.debug("eval %s [guard suppressed]", g)
if not self._suppress_guards_tls():
assert guard is not None
assert tb is not None
self.refine_ranges(guard)
if self.log.isEnabledFor(logging.INFO):
for frame in reversed(tb):
if frame.filename not in uninteresting_files():
break
# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
maybe_user_loc = ""
user_tb = TracingContext.extract_stack()
if user_tb:
maybe_user_loc = " at " + format_frame(user_tb[-1])
is_debug = self.log.isEnabledFor(logging.DEBUG)
maybe_extra_debug = ""
if is_debug and user_tb:
maybe_extra_debug = (
'\nUser Stack (most recent call last):\n' +
' (snipped, see stack below for prefix)\n' +
''.join(traceback.format_list(user_tb))
)
self.log.info(
"eval %s [guard added]%s (%s)%s",
g,
maybe_user_loc,
format_frame(frame),
maybe_extra_debug,
stack_info=is_debug,
)
else:
self.log.debug("eval %s [guard suppressed]", g)
return concrete_val

View File

@ -5,7 +5,7 @@ import operator
import sympy
from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
import torch
from torch._dynamo.exc import TorchDynamoException
@ -524,31 +524,6 @@ try:
assert r == z3.unsat
log.debug("translation validation: success")
class ValidationException(TorchDynamoException):
def __init__(
self,
model,
assertions: Iterable[z3.ExprRef],
target_exprs: Iterable[z3.ExprRef],
failed_source_exprs: Iterable[z3.ExprRef]
) -> None:
model_str = self._joinlines(model, lambda sym: f"{sym}: {model[sym]}")
assertions_str = self._joinlines(assertions, z3str)
target_exprs_str = self._joinlines(target_exprs, z3str)
failed_source_exprs_str = self._joinlines(failed_source_exprs, z3str)
super().__init__(
"translation validation failed.\n\n"
"Model:\n" + model_str + "\n\n"
"Assertions:\n" + assertions_str + "\n\n"
"Target Expressions:\n" + target_exprs_str + "\n\n"
"Failed Source Expressions:\n" + failed_source_exprs_str + "\n\n"
)
def _joinlines(self, xs: Iterable[Any], f: Callable[[Any], str] = lambda x: x) -> str:
return "\n".join(f" ==> {f(x)}" for x in xs)
except ImportError:
_HAS_Z3 = False
else:
@ -573,5 +548,29 @@ def assert_z3_installed_if_tv_set():
"z3-solver or disable translation validation."
)
class ValidationException(TorchDynamoException):
def __init__(self, model, assertions, target_exprs, failed_source_exprs):
assert _HAS_Z3
def symbolstr(sym) -> str:
return f"{sym}: {model[sym]}"
def joinlines(xs) -> str:
return "\n".join(f" ==> {x}" for x in xs)
model_str = joinlines(map(symbolstr, model))
assertions_str = joinlines(map(z3str, assertions))
target_exprs_str = joinlines(map(z3str, target_exprs))
failed_source_exprs_str = joinlines(map(z3str, failed_source_exprs))
super().__init__(
"translation validation failed.\n\n"
"Model:\n" + model_str + "\n\n"
"Assertions:\n" + assertions_str + "\n\n"
"Target Expressions:\n" + target_exprs_str + "\n\n"
"Failed Source Expressions:\n" + failed_source_exprs_str
)
# Checks when this module is loaded.
assert_z3_installed_if_tv_set()