Run translation validation on tracing error. (#106645)

This PR wraps `InstructionTranslator` run with a try-catch block so as to run the
translation validation (TV) if it ends up raising an error.

In this context, we run TV so as to catch simplification errors. These may turn
`ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect.

For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since
it's run only in the end of the tracing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106645
Approved by: https://github.com/ezyang
This commit is contained in:
Yukio Siraichi
2023-08-12 18:07:55 -03:00
committed by PyTorch MergeBot
parent 937cd3742b
commit d8ad74857c
6 changed files with 365 additions and 271 deletions

View File

@ -1702,171 +1702,165 @@ class TestModule(torch.nn.Module):
self.assertIs(kwargs["the_template"], inp2)
class TestTranslationValidator(TestCase):
def _prepare_for_translation_validation(self):
from torch.fx.experimental.validator import TranslationValidator
if TEST_Z3:
import z3
validator = TranslationValidator()
import torch._dynamo.config
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
from torch.utils._sympy.functions import FloorDiv, Mod
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
class TestTranslationValidation(TestCase):
def _prepare_for_translation_validation(self):
validator = TranslationValidator()
return (s0, s1, s2), (z0, z1, z2), validator
# SymPy symbols.
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_sympy_to_z3_translation(self):
import z3
from torch.utils._sympy.functions import FloorDiv, Mod
from torch.fx.experimental.validator import SympyToZ3
# Z3 symbols.
[validator.add_var(s, int) for s in (s0, s1, s2)]
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
return (s0, s1, s2), (z0, z1, z2), validator
def test_sympy_to_z3(self):
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
test_cases = [
# Integer constants.
(sympy.S.Zero, z3.IntVal(0)),
(sympy.S.One, z3.IntVal(1)),
(sympy.S.NegativeOne, z3.IntVal(-1)),
(sympy.Integer(2), z3.IntVal(2)),
(
s0,
z0,
),
# Arithmetic operations.
*[
(op(s0, s1), op(z0, z1))
for op in (
operator.add,
operator.mul,
operator.pow,
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
],
# Logical operations.
*[
(sympy_op(s0, s1), z3_op(z0, z1))
for sympy_op, z3_op in (
(sympy.Eq, operator.eq),
(sympy.Ne, operator.ne),
(sympy.Lt, operator.lt),
(sympy.Le, operator.le),
(sympy.Gt, operator.gt),
(sympy.Ge, operator.ge),
)
],
# Other operations.
def test_sat(self):
(
s0 - s1,
z0 + z3.IntVal(-1) * z1,
),
(
s0 / s1,
z3.ToReal(z0) * (z1**-1),
),
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
(
Mod(s2, (s0 / s1)),
z2
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
* (z3.ToReal(z0) * z1**-1),
),
(
Mod(s2, s0**3),
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
),
]
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
toZ3 = SympyToZ3(validator)
for sympy_expr, z3_expr in test_cases:
result = toZ3.run(sympy_expr)
self.assertTrue(
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
)
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_sat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
validator.add_target_expr(s1 > s0**2)
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_translation_validation_unsat(self):
from torch.fx.experimental.validator import ValidationException
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
def test_z3str(self):
import z3
from torch.fx.experimental.validator import z3str
def test_unsat(self):
(
(s0, s1, s2),
(z0, z1, z2),
validator,
) = self._prepare_for_translation_validation()
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")
validator.add_source_expr(z0 > 5)
validator.add_source_expr(z1 / 2 > z0)
test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
# Solutions for target is NOT a subset of the solutions for the source.
validator.add_target_expr(s0 > 20)
# This expression is less restrictive than its counterpart.
validator.add_target_expr(s1 > s0 + 2)
for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
validator.validate()
def test_z3str(self):
a = z3.Int("a")
b = z3.Int("b")
special = z3.Real("this.size()[2]")
test_cases = [
(z3.IntVal(42), "42"),
# Variable.
(a, "a"),
# Name with special characters.
(special, "this.size()[2]"),
# Renamed function fpplications.
(a != b, "(!= a b)"),
(a ** b, "(pow a b)"),
# Chain of associative operations.
*[
(op(op(a, 5), b), f"({opstr} 5 a b)")
for op, opstr in [
(operator.add, "+"),
(operator.mul, "*")
]
],
# Revert 'Not' conversions.
(a != b, "(!= a b)"),
(a < b, "(> b a)"),
(a > b, "(> a b)"),
# Ignore 'ToInt' and 'ToReal' functions.
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
(z3.ToReal(a + b), "(+ a b)"),
# Convert to floor division: 'idiv'.
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
]
for expr, expected in test_cases:
self.assertEqual(z3str(expr), expected)
instantiate_device_type_tests(TestNormalizeOperators, globals())