mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Run translation validation on tracing error. (#106645)
This PR wraps `InstructionTranslator` run with a try-catch block so as to run the translation validation (TV) if it ends up raising an error. In this context, we run TV so as to catch simplification errors. These may turn `ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect. For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since it's run only in the end of the tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106645 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
937cd3742b
commit
d8ad74857c
@ -1702,171 +1702,165 @@ class TestModule(torch.nn.Module):
|
||||
self.assertIs(kwargs["the_template"], inp2)
|
||||
|
||||
|
||||
class TestTranslationValidator(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
from torch.fx.experimental.validator import TranslationValidator
|
||||
if TEST_Z3:
|
||||
import z3
|
||||
|
||||
validator = TranslationValidator()
|
||||
import torch._dynamo.config
|
||||
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
class TestTranslationValidation(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
validator = TranslationValidator()
|
||||
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_sympy_to_z3_translation(self):
|
||||
import z3
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
from torch.fx.experimental.validator import SympyToZ3
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
|
||||
def test_sympy_to_z3(self):
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
|
||||
def test_sat(self):
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_sat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_unsat(self):
|
||||
from torch.fx.experimental.validator import ValidationException
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_z3str(self):
|
||||
import z3
|
||||
from torch.fx.experimental.validator import z3str
|
||||
def test_unsat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
a = z3.Int("a")
|
||||
b = z3.Int("b")
|
||||
special = z3.Real("this.size()[2]")
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
test_cases = [
|
||||
(z3.IntVal(42), "42"),
|
||||
# Variable.
|
||||
(a, "a"),
|
||||
# Name with special characters.
|
||||
(special, "this.size()[2]"),
|
||||
# Renamed function fpplications.
|
||||
(a != b, "(!= a b)"),
|
||||
(a ** b, "(pow a b)"),
|
||||
# Chain of associative operations.
|
||||
*[
|
||||
(op(op(a, 5), b), f"({opstr} 5 a b)")
|
||||
for op, opstr in [
|
||||
(operator.add, "+"),
|
||||
(operator.mul, "*")
|
||||
]
|
||||
],
|
||||
# Revert 'Not' conversions.
|
||||
(a != b, "(!= a b)"),
|
||||
(a < b, "(> b a)"),
|
||||
(a > b, "(> a b)"),
|
||||
# Ignore 'ToInt' and 'ToReal' functions.
|
||||
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
|
||||
(z3.ToReal(a + b), "(+ a b)"),
|
||||
# Convert to floor division: 'idiv'.
|
||||
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
|
||||
]
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
for expr, expected in test_cases:
|
||||
self.assertEqual(z3str(expr), expected)
|
||||
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
|
||||
validator.validate()
|
||||
|
||||
def test_z3str(self):
|
||||
a = z3.Int("a")
|
||||
b = z3.Int("b")
|
||||
special = z3.Real("this.size()[2]")
|
||||
|
||||
test_cases = [
|
||||
(z3.IntVal(42), "42"),
|
||||
# Variable.
|
||||
(a, "a"),
|
||||
# Name with special characters.
|
||||
(special, "this.size()[2]"),
|
||||
# Renamed function fpplications.
|
||||
(a != b, "(!= a b)"),
|
||||
(a ** b, "(pow a b)"),
|
||||
# Chain of associative operations.
|
||||
*[
|
||||
(op(op(a, 5), b), f"({opstr} 5 a b)")
|
||||
for op, opstr in [
|
||||
(operator.add, "+"),
|
||||
(operator.mul, "*")
|
||||
]
|
||||
],
|
||||
# Revert 'Not' conversions.
|
||||
(a != b, "(!= a b)"),
|
||||
(a < b, "(> b a)"),
|
||||
(a > b, "(> a b)"),
|
||||
# Ignore 'ToInt' and 'ToReal' functions.
|
||||
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
|
||||
(z3.ToReal(a + b), "(+ a b)"),
|
||||
# Convert to floor division: 'idiv'.
|
||||
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
|
||||
]
|
||||
|
||||
for expr, expected in test_cases:
|
||||
self.assertEqual(z3str(expr), expected)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestNormalizeOperators, globals())
|
||||
|
||||
Reference in New Issue
Block a user