More readable Z3 expressions printer. (#106643)

This PR makes Z3 expressions easier to read and understand by creating a custom printer
for them.

Z3 expressions can be printed in 2 forms:

1. Using the builtin `str(e)` function
2. Using the `e.sexpr()` method

Problem is that (1) is a bit hard to read because its line breaks are not so
intuitive. (2) is a bit nicer, but the `to_int` and `to_real` functions clutter things up.

The custom printer is an improved `sexpr()` function:

- Leaves everything in one line
- Gets rid of `to_int` and `to_real` functions
- Reconstruct the floor division operations
- Merge commutative operation chains

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106643
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi
2023-08-05 16:51:55 -03:00
committed by PyTorch MergeBot
parent 26846546e8
commit 33e70e34a3
4 changed files with 304 additions and 186 deletions

View File

@ -44,11 +44,7 @@ from torch.ao.quantization.fake_quantize import FakeQuantize
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx
from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
FloorDiv,
Mod,
)
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
from torch.nn import functional as F
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_SDPA,
@ -56,7 +52,7 @@ from torch.testing._internal.common_cuda import (
TEST_CUDA,
TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE, TEST_Z3
from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE
from torch.testing._internal.jit_utils import JitTestCase
mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
@ -6226,133 +6222,6 @@ def ___make_guard_fn():
self.assertEqual(counter.frame_count, 1)
self.assertEqual(counter.op_count, 9)
def _prepare_for_translation_validation(self):
from torch.fx.experimental.validator import TranslationValidator
validator = TranslationValidator()
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
return (s0, s1, s2), (z0, z1, z2), validator
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_sympy_to_z3_translation(self):
import z3
from torch.fx.experimental.validator import SympyToZ3
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
r = validator.validate()
self.assertEqual(r.success, True, msg=f"failed with model: {r.model}")
self.assertIsNone(r.model)
self.assertIsNone(r.failed_source_expr)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_unsat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
r = validator.validate()
self.assertEqual(r.success, False, msg=f"failed with model: {r.model}")
self.assertIsNotNone(r.model)
self.assertIsNotNone(r.failed_source_expr)
def test_simple_set_usage(self):
def foo(x, y):
setty = {x, y}

View File

@ -5,6 +5,7 @@ import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
@ -48,7 +49,7 @@ from torch.testing._internal.common_device_type import (
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
try:
@ -1700,6 +1701,174 @@ class TestModule(torch.nn.Module):
self.assertIs(kwargs["input"], inp1)
self.assertIs(kwargs["the_template"], inp2)
class TestTranslationValidator(TestCase):
def _prepare_for_translation_validation(self):
from torch.fx.experimental.validator import TranslationValidator
validator = TranslationValidator()
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
return (s0, s1, s2), (z0, z1, z2), validator
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_sympy_to_z3_translation(self):
import z3
from torch.utils._sympy.functions import FloorDiv, Mod
from torch.fx.experimental.validator import SympyToZ3
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_unsat(self):
from torch.fx.experimental.validator import ValidationException
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_z3str(self):
import z3
from torch.fx.experimental.validator import z3str
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")
test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
instantiate_device_type_tests(TestNormalizeOperators, globals())
if __name__ == "__main__":

View File

@ -2012,39 +2012,8 @@ class ShapeEnv:
self.validator.add_assertion(expr)
def _check_translation_validate(self) -> None:
if not _translation_validation_enabled():
return
result = self.validator.validate()
if result.success:
return
if result.model is None:
reason = "no answer"
source_exprs = self.validator._source_exprs
failed = ""
else:
assert result.failed_source_expr is not None
reason = "model: %s" % {sym: result.model[sym] for sym in result.model}
source_exprs = result.failed_source_expr
failed = "Failed "
def exprs_to_str(exprs):
return "\n".join(f"==> {e}" for e in exprs)
assertions = self.validator._assertions
target_exprs = self.validator._target_exprs
raise RuntimeError(f"""translation validation failed with {reason}.
Assertions:
{exprs_to_str(assertions)}
Target Guards:
{exprs_to_str(target_exprs)}
{failed}Source Guards:
{exprs_to_str(source_exprs)}""")
if _translation_validation_enabled():
self.validator.validate()
def create_fx_call_function(
self,

View File

@ -5,9 +5,10 @@ import operator
import sympy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
import torch
from torch._dynamo.exc import TorchDynamoException
import torch.fx
import torch.fx.traceback as fx_traceback
@ -42,6 +43,101 @@ try:
# 'ShapeEnv.evaluate_expr' function. Finally, we run the validation.
# (see [Note: TranslationValidator])
# Better Z3 to string implementation (for a small fraction of Z3).
#
# Here are the things we clean before showing the Z3 expression:
# - Rename a few ops (e.g. "Distinct" ==> "!=")
#
# - Ignore ToInt and ToReal operations:
# usually they don't really matter
#
# - Transform (ToInt (/ ...)) into (idiv ...):
# this is the pattern for floor division
#
# - Collect a chain of the same operations into one
def z3str(e: z3.ExprRef) -> str:
assert z3.is_expr(e), f"unsupported expression type: {e}"
def get_args_str(e: z3.ExprRef) -> List[str]:
return [z3str(e.arg(i)) for i in range(e.num_args())]
# First, we simplify the given expression.
# This is done using rewriting rules, so shouldn't take long.
e = z3.simplify(e)
# Only support function applications.
# Even Z3 "variables" are, in fact, function applications.
if not z3.is_app(e):
raise ValueError(f"can't print Z3 expression: {e}")
if z3.is_int_value(e) or z3.is_rational_value(e):
return e.as_string() # type: ignore[attr-defined]
decl = e.decl()
kind = decl.kind()
op = str(decl)
args = get_args_str(e)
if kind == z3.Z3_OP_POWER:
op = "pow"
elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL):
# Collect the arguments of chains of ADD and MUL.
# This is safe, since they are associative.
def collect_str_args(e):
if not (z3.is_app(e) and e.decl().kind() == kind):
return [z3str(e)]
else:
return [
x
for i in range(e.num_args())
for x in collect_str_args(e.arg(i))
]
args = collect_str_args(e)
elif kind == z3.Z3_OP_NOT:
# Revert some conversions that z3.simplify applies:
# - a != b ==> (Not (== a b)) ==> (!= a b)
# - a < b ==> (Not (<= b a)) ==> (> b a)
# - a > b ==> (Not (<= a b)) ==> (> a b)
assert e.num_args() == 1
arg = e.arg(0)
assert z3.is_app(arg)
argkind = arg.decl().kind()
logic_inverse = {
z3.Z3_OP_EQ: "!=",
z3.Z3_OP_LE: ">",
z3.Z3_OP_GE: "<",
}
if argkind in logic_inverse:
op = logic_inverse[argkind]
args = get_args_str(arg)
elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL):
assert e.num_args() == 1
argstr = z3str(e.arg(0))
# Check if it's the floor division pattern.
if argstr.startswith("(/"):
return "(idiv" + argstr[2:]
# Otherwise, just ignore it.
return argstr
elif kind == z3.Z3_OP_UNINTERPRETED:
assert e.num_args() == 0
return str(decl)
string = op + " " + " ".join(args)
return f"({string.rstrip()})"
# Implementation of Python semantics as Z3 expressions.
#
# Z3 Real-Int theory has operators with semantics that differ that of
@ -363,24 +459,13 @@ try:
assert isinstance(ref, z3.BoolRef)
self._assertions.add(ref)
# The result of a validation run.
@dataclass
class Result:
success: bool
# Mapping of the name of each free variable to the value assigned to it.
model: Optional[z3.ModelRef] = None
# List of the source expressions that failed due to the assignment.
failed_source_expr: Optional[List[z3.BoolRef]] = None
def validate(self) -> "TranslationValidator.Result":
def validate(self) -> None:
from torch._dynamo.utils import dynamo_timed
if len(self._source_exprs) == 0 or len(self._target_exprs) == 0:
# If there are no source/target expressions, there's nothing we really
# wish to prove. So, we just return.
return self.Result(success=True)
return None
# Here, we use "QF_NRA" logic for the solver:
# "Quantifier-free Non-linear Real Arithmetic".
@ -411,10 +496,11 @@ try:
# Target expressions are unsound.
# Log the found model and the source expressions that failed.
model = solver.model()
return self.Result(
success=False,
model=model,
failed_source_expr=[inp for inp in self._source_exprs if not model.evaluate(inp)],
raise ValidationException(
model, self._assertions, self._target_exprs,
failed_source_exprs=[
inp for inp in self._source_exprs if not model.evaluate(inp)
]
)
else:
if r == z3.unknown:
@ -426,7 +512,32 @@ try:
# Target expressions are sound.
assert r == z3.unsat
log.debug("translation validation: success")
return self.Result(success=True)
class ValidationException(TorchDynamoException):
def __init__(
self,
model,
assertions: Iterable[z3.ExprRef],
target_exprs: Iterable[z3.ExprRef],
failed_source_exprs: Iterable[z3.ExprRef]
) -> None:
model_str = self._joinlines(model, lambda sym: f"{sym}: {model[sym]}")
assertions_str = self._joinlines(assertions, z3str)
target_exprs_str = self._joinlines(target_exprs, z3str)
failed_source_exprs_str = self._joinlines(failed_source_exprs, z3str)
super().__init__(
"translation validation failed.\n\n"
"Model:\n" + model_str + "\n\n"
"Assertions:\n" + assertions_str + "\n\n"
"Target Expressions:\n" + target_exprs_str + "\n\n"
"Failed Source Expressions:\n" + failed_source_exprs_str + "\n\n"
)
def _joinlines(self, xs: Iterable[Any], f: Callable[[Any], str] = lambda x: x) -> str:
return "\n".join(f" ==> {f(x)}" for x in xs)
except ImportError:
_HAS_Z3 = False
else: