diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 10e607796016..5b1a7a447545 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -44,11 +44,7 @@ from torch.ao.quantization.fake_quantize import FakeQuantize from torch.ao.quantization.qconfig import QConfig from torch.ao.quantization.quantize_fx import prepare_qat_fx from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler -from torch.fx.experimental.symbolic_shapes import ( - ConstraintViolationError, - FloorDiv, - Mod, -) +from torch.fx.experimental.symbolic_shapes import ConstraintViolationError from torch.nn import functional as F from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FUSED_SDPA, @@ -56,7 +52,7 @@ from torch.testing._internal.common_cuda import ( TEST_CUDA, TEST_MULTIGPU, ) -from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE, TEST_Z3 +from torch.testing._internal.common_utils import freeze_rng_state, IS_FBCODE from torch.testing._internal.jit_utils import JitTestCase mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) @@ -6226,133 +6222,6 @@ def ___make_guard_fn(): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) - def _prepare_for_translation_validation(self): - from torch.fx.experimental.validator import TranslationValidator - - validator = TranslationValidator() - - # SymPy symbols. - s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True) - - # Z3 symbols. - [validator.add_var(s, int) for s in (s0, s1, s2)] - z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2)) - - return (s0, s1, s2), (z0, z1, z2), validator - - @unittest.skipIf(not TEST_Z3, "Z3 not installed") - def test_sympy_to_z3_translation(self): - import z3 - from torch.fx.experimental.validator import SympyToZ3 - - ( - (s0, s1, s2), - (z0, z1, z2), - validator, - ) = self._prepare_for_translation_validation() - - test_cases = [ - # Integer constants. - (sympy.S.Zero, z3.IntVal(0)), - (sympy.S.One, z3.IntVal(1)), - (sympy.S.NegativeOne, z3.IntVal(-1)), - (sympy.Integer(2), z3.IntVal(2)), - ( - s0, - z0, - ), - # Arithmetic operations. - *[ - (op(s0, s1), op(z0, z1)) - for op in ( - operator.add, - operator.mul, - operator.pow, - ) - ], - # Logical operations. - *[ - (sympy_op(s0, s1), z3_op(z0, z1)) - for sympy_op, z3_op in ( - (sympy.Eq, operator.eq), - (sympy.Ne, operator.ne), - (sympy.Lt, operator.lt), - (sympy.Le, operator.le), - (sympy.Gt, operator.gt), - (sympy.Ge, operator.ge), - ) - ], - # Other operations. - ( - s0 - s1, - z0 + z3.IntVal(-1) * z1, - ), - ( - s0 / s1, - z3.ToReal(z0) * (z1**-1), - ), - (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))), - (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1), - ( - Mod(s2, (s0 / s1)), - z2 - - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1))) - * (z3.ToReal(z0) * z1**-1), - ), - ( - Mod(s2, s0**3), - z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3, - ), - ] - - toZ3 = SympyToZ3(validator) - for sympy_expr, z3_expr in test_cases: - result = toZ3.run(sympy_expr) - self.assertTrue( - z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}" - ) - - @unittest.skipIf(not TEST_Z3, "Z3 not installed") - def test_translation_validation_sat(self): - ( - (s0, s1, s2), - (z0, z1, z2), - validator, - ) = self._prepare_for_translation_validation() - - validator.add_source_expr(z0 > 5) - validator.add_source_expr(z1 / 2 > z0) - - # Solutions for target is a subset of the solutions for the source. - validator.add_target_expr(s0 > 20) - validator.add_target_expr(s1 > s0**2) - - r = validator.validate() - self.assertEqual(r.success, True, msg=f"failed with model: {r.model}") - self.assertIsNone(r.model) - self.assertIsNone(r.failed_source_expr) - - @unittest.skipIf(not TEST_Z3, "Z3 not installed") - def test_translation_validation_unsat(self): - ( - (s0, s1, s2), - (z0, z1, z2), - validator, - ) = self._prepare_for_translation_validation() - - validator.add_source_expr(z0 > 5) - validator.add_source_expr(z1 / 2 > z0) - - # Solutions for target is NOT a subset of the solutions for the source. - validator.add_target_expr(s0 > 20) - # This expression is less restrictive than its counterpart. - validator.add_target_expr(s1 > s0 + 2) - - r = validator.validate() - self.assertEqual(r.success, False, msg=f"failed with model: {r.model}") - self.assertIsNotNone(r.model) - self.assertIsNotNone(r.failed_source_expr) - def test_simple_set_usage(self): def foo(x, y): setty = {x, y} diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index e8ca74241d59..2e8f95d2f2ee 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -5,6 +5,7 @@ import numbers import operator import pickle import sys +import sympy import tempfile import unittest from types import BuiltinFunctionType @@ -48,7 +49,7 @@ from torch.testing._internal.common_device_type import ( ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, new_module_tests -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase from torch.testing._internal.jit_utils import JitTestCase try: @@ -1700,6 +1701,174 @@ class TestModule(torch.nn.Module): self.assertIs(kwargs["input"], inp1) self.assertIs(kwargs["the_template"], inp2) + +class TestTranslationValidator(TestCase): + def _prepare_for_translation_validation(self): + from torch.fx.experimental.validator import TranslationValidator + + validator = TranslationValidator() + + # SymPy symbols. + s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True) + + # Z3 symbols. + [validator.add_var(s, int) for s in (s0, s1, s2)] + z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2)) + + return (s0, s1, s2), (z0, z1, z2), validator + + @unittest.skipIf(not TEST_Z3, "Z3 not installed") + def test_sympy_to_z3_translation(self): + import z3 + from torch.utils._sympy.functions import FloorDiv, Mod + from torch.fx.experimental.validator import SympyToZ3 + + ( + (s0, s1, s2), + (z0, z1, z2), + validator, + ) = self._prepare_for_translation_validation() + + test_cases = [ + # Integer constants. + (sympy.S.Zero, z3.IntVal(0)), + (sympy.S.One, z3.IntVal(1)), + (sympy.S.NegativeOne, z3.IntVal(-1)), + (sympy.Integer(2), z3.IntVal(2)), + ( + s0, + z0, + ), + # Arithmetic operations. + *[ + (op(s0, s1), op(z0, z1)) + for op in ( + operator.add, + operator.mul, + operator.pow, + ) + ], + # Logical operations. + *[ + (sympy_op(s0, s1), z3_op(z0, z1)) + for sympy_op, z3_op in ( + (sympy.Eq, operator.eq), + (sympy.Ne, operator.ne), + (sympy.Lt, operator.lt), + (sympy.Le, operator.le), + (sympy.Gt, operator.gt), + (sympy.Ge, operator.ge), + ) + ], + # Other operations. + ( + s0 - s1, + z0 + z3.IntVal(-1) * z1, + ), + ( + s0 / s1, + z3.ToReal(z0) * (z1**-1), + ), + (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))), + (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1), + ( + Mod(s2, (s0 / s1)), + z2 + - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1))) + * (z3.ToReal(z0) * z1**-1), + ), + ( + Mod(s2, s0**3), + z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3, + ), + ] + + toZ3 = SympyToZ3(validator) + for sympy_expr, z3_expr in test_cases: + result = toZ3.run(sympy_expr) + self.assertTrue( + z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}" + ) + + @unittest.skipIf(not TEST_Z3, "Z3 not installed") + def test_translation_validation_sat(self): + ( + (s0, s1, s2), + (z0, z1, z2), + validator, + ) = self._prepare_for_translation_validation() + + validator.add_source_expr(z0 > 5) + validator.add_source_expr(z1 / 2 > z0) + + # Solutions for target is a subset of the solutions for the source. + validator.add_target_expr(s0 > 20) + validator.add_target_expr(s1 > s0**2) + + validator.validate() + + @unittest.skipIf(not TEST_Z3, "Z3 not installed") + def test_translation_validation_unsat(self): + from torch.fx.experimental.validator import ValidationException + + ( + (s0, s1, s2), + (z0, z1, z2), + validator, + ) = self._prepare_for_translation_validation() + + validator.add_source_expr(z0 > 5) + validator.add_source_expr(z1 / 2 > z0) + + # Solutions for target is NOT a subset of the solutions for the source. + validator.add_target_expr(s0 > 20) + # This expression is less restrictive than its counterpart. + validator.add_target_expr(s1 > s0 + 2) + + with self.assertRaisesRegex(ValidationException, "translation validation failed."): + validator.validate() + + @unittest.skipIf(not TEST_Z3, "Z3 not installed") + def test_z3str(self): + import z3 + from torch.fx.experimental.validator import z3str + + a = z3.Int("a") + b = z3.Int("b") + special = z3.Real("this.size()[2]") + + test_cases = [ + (z3.IntVal(42), "42"), + # Variable. + (a, "a"), + # Name with special characters. + (special, "this.size()[2]"), + # Renamed function fpplications. + (a != b, "(!= a b)"), + (a ** b, "(pow a b)"), + # Chain of associative operations. + *[ + (op(op(a, 5), b), f"({opstr} 5 a b)") + for op, opstr in [ + (operator.add, "+"), + (operator.mul, "*") + ] + ], + # Revert 'Not' conversions. + (a != b, "(!= a b)"), + (a < b, "(> b a)"), + (a > b, "(> a b)"), + # Ignore 'ToInt' and 'ToReal' functions. + (z3.ToInt(special) + a, "(+ this.size()[2] a)"), + (z3.ToReal(a + b), "(+ a b)"), + # Convert to floor division: 'idiv'. + (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"), + ] + + for expr, expected in test_cases: + self.assertEqual(z3str(expr), expected) + + instantiate_device_type_tests(TestNormalizeOperators, globals()) if __name__ == "__main__": diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3d5e65f0180e..55f1e64a0b7f 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2012,39 +2012,8 @@ class ShapeEnv: self.validator.add_assertion(expr) def _check_translation_validate(self) -> None: - if not _translation_validation_enabled(): - return - - result = self.validator.validate() - - if result.success: - return - - if result.model is None: - reason = "no answer" - source_exprs = self.validator._source_exprs - failed = "" - else: - assert result.failed_source_expr is not None - reason = "model: %s" % {sym: result.model[sym] for sym in result.model} - source_exprs = result.failed_source_expr - failed = "Failed " - - def exprs_to_str(exprs): - return "\n".join(f"==> {e}" for e in exprs) - - assertions = self.validator._assertions - target_exprs = self.validator._target_exprs - - raise RuntimeError(f"""translation validation failed with {reason}. -Assertions: -{exprs_to_str(assertions)} - -Target Guards: -{exprs_to_str(target_exprs)} - -{failed}Source Guards: -{exprs_to_str(source_exprs)}""") + if _translation_validation_enabled(): + self.validator.validate() def create_fx_call_function( self, diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index f05f71e14e3c..df993fbba2a2 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -5,9 +5,10 @@ import operator import sympy from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union import torch +from torch._dynamo.exc import TorchDynamoException import torch.fx import torch.fx.traceback as fx_traceback @@ -42,6 +43,101 @@ try: # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. # (see [Note: TranslationValidator]) + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> List[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + # Implementation of Python semantics as Z3 expressions. # # Z3 Real-Int theory has operators with semantics that differ that of @@ -363,24 +459,13 @@ try: assert isinstance(ref, z3.BoolRef) self._assertions.add(ref) - # The result of a validation run. - @dataclass - class Result: - success: bool - - # Mapping of the name of each free variable to the value assigned to it. - model: Optional[z3.ModelRef] = None - - # List of the source expressions that failed due to the assignment. - failed_source_expr: Optional[List[z3.BoolRef]] = None - - def validate(self) -> "TranslationValidator.Result": + def validate(self) -> None: from torch._dynamo.utils import dynamo_timed if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: # If there are no source/target expressions, there's nothing we really # wish to prove. So, we just return. - return self.Result(success=True) + return None # Here, we use "QF_NRA" logic for the solver: # "Quantifier-free Non-linear Real Arithmetic". @@ -411,10 +496,11 @@ try: # Target expressions are unsound. # Log the found model and the source expressions that failed. model = solver.model() - return self.Result( - success=False, - model=model, - failed_source_expr=[inp for inp in self._source_exprs if not model.evaluate(inp)], + raise ValidationException( + model, self._assertions, self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ] ) else: if r == z3.unknown: @@ -426,7 +512,32 @@ try: # Target expressions are sound. assert r == z3.unsat log.debug("translation validation: success") - return self.Result(success=True) + + + class ValidationException(TorchDynamoException): + def __init__( + self, + model, + assertions: Iterable[z3.ExprRef], + target_exprs: Iterable[z3.ExprRef], + failed_source_exprs: Iterable[z3.ExprRef] + ) -> None: + model_str = self._joinlines(model, lambda sym: f"{sym}: {model[sym]}") + assertions_str = self._joinlines(assertions, z3str) + target_exprs_str = self._joinlines(target_exprs, z3str) + failed_source_exprs_str = self._joinlines(failed_source_exprs, z3str) + + super().__init__( + "translation validation failed.\n\n" + "Model:\n" + model_str + "\n\n" + "Assertions:\n" + assertions_str + "\n\n" + "Target Expressions:\n" + target_exprs_str + "\n\n" + "Failed Source Expressions:\n" + failed_source_exprs_str + "\n\n" + ) + + def _joinlines(self, xs: Iterable[Any], f: Callable[[Any], str] = lambda x: x) -> str: + return "\n".join(f" ==> {f(x)}" for x in xs) + except ImportError: _HAS_Z3 = False else: