mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Run translation validation on tracing error. (#106645)
This PR wraps `InstructionTranslator` run with a try-catch block so as to run the translation validation (TV) if it ends up raising an error. In this context, we run TV so as to catch simplification errors. These may turn `ShapeEnv.divisible` and `ShapeEnv.replacements` incorrect. For example: #101173 describes a SymPy simplification bug that doesn't reach TV, since it's run only in the end of the tracing. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106645 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
937cd3742b
commit
d8ad74857c
@ -9,7 +9,8 @@ import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
from torch._dynamo.comptime import comptime
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch.testing._internal.common_utils import munge_exc
|
||||
from torch.testing._internal.common_device_type import skipIf
|
||||
from torch.testing._internal.common_utils import munge_exc, TEST_Z3
|
||||
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
||||
|
||||
|
||||
@ -189,6 +190,53 @@ backend='relu_compile_error_TESTING_ONLY' raised:
|
||||
ReluCompileError:""",
|
||||
)
|
||||
|
||||
@skipIf(not TEST_Z3, "z3 not installed")
|
||||
@torch._dynamo.config.patch(
|
||||
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY=True,
|
||||
assume_static_by_default=False,
|
||||
translation_validation=True,
|
||||
suppress_errors=False,
|
||||
)
|
||||
def test_trigger_on_error(self):
|
||||
from torch.fx.experimental.validator import ValidationException
|
||||
|
||||
@torch.compile
|
||||
def fn(x):
|
||||
return x.reshape(-1, 4)
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
ValidationException,
|
||||
lambda: fn(torch.randn(20)),
|
||||
"""\
|
||||
translation validation failed.
|
||||
|
||||
Model:
|
||||
==> L['x'].storage_offset(): 0
|
||||
==> s0: 4
|
||||
==> L['x'].stride()[0]: 1
|
||||
==> L['x'].size()[0]: 4
|
||||
|
||||
Assertions:
|
||||
==> (== L['x'].size()[0] s0)
|
||||
==> (> s0 1)
|
||||
==> (Not (And (< L['x'].size()[0] 4) (>= L['x'].size()[0] 0)))
|
||||
==> (== 0 L['x'].storage_offset())
|
||||
==> (== 1 L['x'].stride()[0])
|
||||
==> (True)
|
||||
|
||||
Target Expressions:
|
||||
==> (>= 9223372036854775806 s0)
|
||||
==> (== 4 L['x'].size()[0])
|
||||
==> (== 0 L['x'].storage_offset())
|
||||
==> (> s0 0)
|
||||
==> (== 1 L['x'].stride()[0])
|
||||
==> (<= 2 s0)
|
||||
==> (== 4 s0)
|
||||
|
||||
Failed Source Expressions:
|
||||
==> (!= 4 L['x'].size()[0])""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -1702,171 +1702,165 @@ class TestModule(torch.nn.Module):
|
||||
self.assertIs(kwargs["the_template"], inp2)
|
||||
|
||||
|
||||
class TestTranslationValidator(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
from torch.fx.experimental.validator import TranslationValidator
|
||||
if TEST_Z3:
|
||||
import z3
|
||||
|
||||
validator = TranslationValidator()
|
||||
import torch._dynamo.config
|
||||
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
class TestTranslationValidation(TestCase):
|
||||
def _prepare_for_translation_validation(self):
|
||||
validator = TranslationValidator()
|
||||
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
# SymPy symbols.
|
||||
s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_sympy_to_z3_translation(self):
|
||||
import z3
|
||||
from torch.utils._sympy.functions import FloorDiv, Mod
|
||||
from torch.fx.experimental.validator import SympyToZ3
|
||||
# Z3 symbols.
|
||||
[validator.add_var(s, int) for s in (s0, s1, s2)]
|
||||
z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
return (s0, s1, s2), (z0, z1, z2), validator
|
||||
|
||||
def test_sympy_to_z3(self):
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
test_cases = [
|
||||
# Integer constants.
|
||||
(sympy.S.Zero, z3.IntVal(0)),
|
||||
(sympy.S.One, z3.IntVal(1)),
|
||||
(sympy.S.NegativeOne, z3.IntVal(-1)),
|
||||
(sympy.Integer(2), z3.IntVal(2)),
|
||||
(
|
||||
s0,
|
||||
z0,
|
||||
),
|
||||
# Arithmetic operations.
|
||||
*[
|
||||
(op(s0, s1), op(z0, z1))
|
||||
for op in (
|
||||
operator.add,
|
||||
operator.mul,
|
||||
operator.pow,
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
],
|
||||
# Logical operations.
|
||||
*[
|
||||
(sympy_op(s0, s1), z3_op(z0, z1))
|
||||
for sympy_op, z3_op in (
|
||||
(sympy.Eq, operator.eq),
|
||||
(sympy.Ne, operator.ne),
|
||||
(sympy.Lt, operator.lt),
|
||||
(sympy.Le, operator.le),
|
||||
(sympy.Gt, operator.gt),
|
||||
(sympy.Ge, operator.ge),
|
||||
)
|
||||
],
|
||||
# Other operations.
|
||||
|
||||
def test_sat(self):
|
||||
(
|
||||
s0 - s1,
|
||||
z0 + z3.IntVal(-1) * z1,
|
||||
),
|
||||
(
|
||||
s0 / s1,
|
||||
z3.ToReal(z0) * (z1**-1),
|
||||
),
|
||||
(FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
|
||||
(Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
|
||||
(
|
||||
Mod(s2, (s0 / s1)),
|
||||
z2
|
||||
- z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
|
||||
* (z3.ToReal(z0) * z1**-1),
|
||||
),
|
||||
(
|
||||
Mod(s2, s0**3),
|
||||
z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
|
||||
),
|
||||
]
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
toZ3 = SympyToZ3(validator)
|
||||
for sympy_expr, z3_expr in test_cases:
|
||||
result = toZ3.run(sympy_expr)
|
||||
self.assertTrue(
|
||||
z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
|
||||
)
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_sat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
validator.add_target_expr(s1 > s0**2)
|
||||
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_translation_validation_unsat(self):
|
||||
from torch.fx.experimental.validator import ValidationException
|
||||
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
|
||||
validator.validate()
|
||||
|
||||
@unittest.skipIf(not TEST_Z3, "Z3 not installed")
|
||||
def test_z3str(self):
|
||||
import z3
|
||||
from torch.fx.experimental.validator import z3str
|
||||
def test_unsat(self):
|
||||
(
|
||||
(s0, s1, s2),
|
||||
(z0, z1, z2),
|
||||
validator,
|
||||
) = self._prepare_for_translation_validation()
|
||||
|
||||
a = z3.Int("a")
|
||||
b = z3.Int("b")
|
||||
special = z3.Real("this.size()[2]")
|
||||
validator.add_source_expr(z0 > 5)
|
||||
validator.add_source_expr(z1 / 2 > z0)
|
||||
|
||||
test_cases = [
|
||||
(z3.IntVal(42), "42"),
|
||||
# Variable.
|
||||
(a, "a"),
|
||||
# Name with special characters.
|
||||
(special, "this.size()[2]"),
|
||||
# Renamed function fpplications.
|
||||
(a != b, "(!= a b)"),
|
||||
(a ** b, "(pow a b)"),
|
||||
# Chain of associative operations.
|
||||
*[
|
||||
(op(op(a, 5), b), f"({opstr} 5 a b)")
|
||||
for op, opstr in [
|
||||
(operator.add, "+"),
|
||||
(operator.mul, "*")
|
||||
]
|
||||
],
|
||||
# Revert 'Not' conversions.
|
||||
(a != b, "(!= a b)"),
|
||||
(a < b, "(> b a)"),
|
||||
(a > b, "(> a b)"),
|
||||
# Ignore 'ToInt' and 'ToReal' functions.
|
||||
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
|
||||
(z3.ToReal(a + b), "(+ a b)"),
|
||||
# Convert to floor division: 'idiv'.
|
||||
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
|
||||
]
|
||||
# Solutions for target is NOT a subset of the solutions for the source.
|
||||
validator.add_target_expr(s0 > 20)
|
||||
# This expression is less restrictive than its counterpart.
|
||||
validator.add_target_expr(s1 > s0 + 2)
|
||||
|
||||
for expr, expected in test_cases:
|
||||
self.assertEqual(z3str(expr), expected)
|
||||
with self.assertRaisesRegex(ValidationException, "translation validation failed."):
|
||||
validator.validate()
|
||||
|
||||
def test_z3str(self):
|
||||
a = z3.Int("a")
|
||||
b = z3.Int("b")
|
||||
special = z3.Real("this.size()[2]")
|
||||
|
||||
test_cases = [
|
||||
(z3.IntVal(42), "42"),
|
||||
# Variable.
|
||||
(a, "a"),
|
||||
# Name with special characters.
|
||||
(special, "this.size()[2]"),
|
||||
# Renamed function fpplications.
|
||||
(a != b, "(!= a b)"),
|
||||
(a ** b, "(pow a b)"),
|
||||
# Chain of associative operations.
|
||||
*[
|
||||
(op(op(a, 5), b), f"({opstr} 5 a b)")
|
||||
for op, opstr in [
|
||||
(operator.add, "+"),
|
||||
(operator.mul, "*")
|
||||
]
|
||||
],
|
||||
# Revert 'Not' conversions.
|
||||
(a != b, "(!= a b)"),
|
||||
(a < b, "(> b a)"),
|
||||
(a > b, "(> a b)"),
|
||||
# Ignore 'ToInt' and 'ToReal' functions.
|
||||
(z3.ToInt(special) + a, "(+ this.size()[2] a)"),
|
||||
(z3.ToReal(a + b), "(+ a b)"),
|
||||
# Convert to floor division: 'idiv'.
|
||||
(z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
|
||||
]
|
||||
|
||||
for expr, expected in test_cases:
|
||||
self.assertEqual(z3str(expr), expected)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestNormalizeOperators, globals())
|
||||
|
||||
@ -290,6 +290,11 @@ capture_func_transforms = True
|
||||
# used for testing
|
||||
inject_BUILD_SET_unimplemented_TESTING_ONLY = False
|
||||
|
||||
# wraps (un)equalities with 'Not' class after recording the correct expression
|
||||
# in the FX graph. This should incorrectly construct the divisible and replacement
|
||||
# lists, and incorrectly issue guards.
|
||||
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False
|
||||
|
||||
_autograd_backward_strict_mode_banned_ops = [
|
||||
"stride",
|
||||
"requires_grad",
|
||||
|
||||
@ -400,6 +400,11 @@ def _compile(
|
||||
frame: Optional[types.FrameType] = None,
|
||||
frame_state=None,
|
||||
) -> Optional[GuardedCode]:
|
||||
from torch.fx.experimental.validator import (
|
||||
translation_validation_enabled,
|
||||
ValidationException,
|
||||
)
|
||||
|
||||
output: Optional[OutputGraph] = None
|
||||
# This is shared across restarts
|
||||
mutated_closure_cell_contents: Set[str] = set()
|
||||
@ -421,8 +426,21 @@ def _compile(
|
||||
mutated_closure_cell_contents,
|
||||
frame_state=frame_state,
|
||||
)
|
||||
with tracing(tracer.output.tracing_context):
|
||||
tracer.run()
|
||||
|
||||
try:
|
||||
with tracing(tracer.output.tracing_context):
|
||||
tracer.run()
|
||||
except (exc.RestartAnalysis, exc.SkipFrame):
|
||||
raise
|
||||
except Exception:
|
||||
if translation_validation_enabled():
|
||||
fakes = tracer.output.tracked_fakes
|
||||
tracer.output.shape_env.produce_guards(
|
||||
[a.fake for a in fakes],
|
||||
[a.source for a in fakes],
|
||||
)
|
||||
raise
|
||||
|
||||
output = tracer.output
|
||||
assert output is not None
|
||||
assert output.output_instructions
|
||||
@ -542,6 +560,7 @@ def _compile(
|
||||
AssertionError,
|
||||
ConstraintViolationError,
|
||||
GuardOnDataDependentSymNode,
|
||||
ValidationException,
|
||||
) as e:
|
||||
fail_reason = str(e)
|
||||
exception_handler(e, code, frame, export=export)
|
||||
|
||||
@ -2072,6 +2072,10 @@ class ShapeEnv:
|
||||
|
||||
return self.fx_node_cache[node_key]
|
||||
|
||||
def remove_fx_node(self, node: Optional[torch.fx.Node]) -> None:
|
||||
if _translation_validation_enabled() and node is not None:
|
||||
self.graph.erase_node(node)
|
||||
|
||||
def _suppress_guards_tls(self):
|
||||
return getattr(TLS, "suppress_guards", False)
|
||||
|
||||
@ -3168,119 +3172,144 @@ class ShapeEnv:
|
||||
#
|
||||
# If all of the above check, we create an FX node representing the
|
||||
# actual expression to be guarded.
|
||||
node = None
|
||||
if (
|
||||
_translation_validation_enabled()
|
||||
and fx_node is not None
|
||||
and not self._suppress_guards_tls()
|
||||
):
|
||||
if concrete_val is sympy.true:
|
||||
self.create_fx_call_function(torch._assert, (fx_node,))
|
||||
node = self.create_fx_call_function(torch._assert, (fx_node,))
|
||||
elif concrete_val is sympy.false:
|
||||
neg = self.create_fx_call_function(operator.not_, (fx_node,))
|
||||
self.create_fx_call_function(torch._assert, (neg,))
|
||||
node = self.create_fx_call_function(torch._assert, (neg,))
|
||||
else:
|
||||
eql = self.create_fx_call_function(operator.eq, (fx_node, concrete_val))
|
||||
self.create_fx_call_function(torch._assert, (eql,))
|
||||
node = self.create_fx_call_function(torch._assert, (eql,))
|
||||
|
||||
if len(orig_expr.free_symbols) == 0:
|
||||
self.log.debug("eval %s [trivial]", orig_expr)
|
||||
# NB: don't test float as there may be precision issues
|
||||
if isinstance(hint, (int, bool)):
|
||||
assert orig_expr == hint, f"{orig_expr} != {hint}"
|
||||
return orig_expr
|
||||
# After creating the FX node corresponding to orig_expr, we must make sure that
|
||||
# no error will be raised until the end of this function.
|
||||
#
|
||||
# Reason: the translation validation may become invalid otherwise.
|
||||
#
|
||||
# If an error is raised before the end of this function, we remove the FX node
|
||||
# inserted, and re-raise the error.
|
||||
guard = None
|
||||
tb = None
|
||||
|
||||
expr = orig_expr
|
||||
try:
|
||||
if len(orig_expr.free_symbols) == 0:
|
||||
self.log.debug("eval %s [trivial]", orig_expr)
|
||||
# NB: don't test float as there may be precision issues
|
||||
if isinstance(hint, (int, bool)):
|
||||
assert orig_expr == hint, f"{orig_expr} != {hint}"
|
||||
return orig_expr
|
||||
|
||||
static_expr = self._maybe_evaluate_static(expr)
|
||||
if static_expr is not None:
|
||||
self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
|
||||
# NB: don't test float as there may be precision issues
|
||||
if isinstance(hint, (int, bool)):
|
||||
assert static_expr == hint, f"{static_expr} != {hint}"
|
||||
return static_expr
|
||||
expr = orig_expr
|
||||
|
||||
if not (expr.free_symbols <= self.var_to_val.keys()):
|
||||
# TODO: dedupe this with _maybe_evaluate_static
|
||||
# Attempt to eliminate the unbacked SymInt
|
||||
new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
|
||||
if not (new_expr.free_symbols <= self.var_to_val.keys()):
|
||||
raise self._make_data_dependent_error(expr.xreplace(self.var_to_val), expr)
|
||||
expr = new_expr
|
||||
static_expr = self._maybe_evaluate_static(expr)
|
||||
if static_expr is not None:
|
||||
self.log.debug("eval %s == %s [statically known]", orig_expr, static_expr)
|
||||
# NB: don't test float as there may be precision issues
|
||||
if isinstance(hint, (int, bool)):
|
||||
assert static_expr == hint, f"{static_expr} != {hint}"
|
||||
return static_expr
|
||||
|
||||
if self.frozen:
|
||||
self.counter["ignored_backward_guard"] += 1
|
||||
signpost_event(
|
||||
"dynamic",
|
||||
"evaluate_expr_frozen",
|
||||
{
|
||||
**self.co_fields,
|
||||
"ignored_guard": f"{expr} == {concrete_val}",
|
||||
# no version = original state (this signpost is expected)
|
||||
# version 2 = dynamic backwards is eagerly compiled
|
||||
"version": 2,
|
||||
},
|
||||
)
|
||||
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
|
||||
if not (expr.free_symbols <= self.var_to_val.keys()):
|
||||
# TODO: dedupe this with _maybe_evaluate_static
|
||||
# Attempt to eliminate the unbacked SymInt
|
||||
new_expr = self._maybe_evaluate_static(expr, unbacked_only=True)
|
||||
if not (new_expr.free_symbols <= self.var_to_val.keys()):
|
||||
raise self._make_data_dependent_error(expr.xreplace(self.var_to_val), expr)
|
||||
expr = new_expr
|
||||
|
||||
if isinstance(expr, (sympy.Eq, sympy.Ne)):
|
||||
self._maybe_guard_eq(expr, bool(concrete_val))
|
||||
# TODO: If we successfully eliminate a symbol via equality, it
|
||||
# is not actually necessary to save a guard for the equality,
|
||||
# as we will implicitly generate a guard when we match that
|
||||
# input against the symbol
|
||||
elif isinstance(concrete_val, sympy.Integer):
|
||||
# WARNING: we cannot actually do simplifications on guards
|
||||
# on floating point values, because Sympy generally does not
|
||||
# think expressions on integers can ever be equal to floating
|
||||
# point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without
|
||||
# very clear algebraic laws that hold for floating point, such
|
||||
# simplifications are error prone anyway, so be sure not to
|
||||
# maybe_guard_eq in those cases.
|
||||
self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True)
|
||||
|
||||
if concrete_val is sympy.true:
|
||||
g = expr
|
||||
elif concrete_val is sympy.false:
|
||||
g = sympy.Not(expr)
|
||||
else:
|
||||
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
|
||||
|
||||
if not self._suppress_guards_tls():
|
||||
tb = traceback.extract_stack()[:-1]
|
||||
stack = ''.join(traceback.format_list(tb))
|
||||
guard = ShapeGuard(g, stack)
|
||||
self.guards.append(guard)
|
||||
self.refine_ranges(guard)
|
||||
if self.log.isEnabledFor(logging.INFO):
|
||||
for frame in reversed(tb):
|
||||
if frame.filename not in uninteresting_files():
|
||||
break
|
||||
|
||||
# NB: this stack is truncated, but it's fine because the main
|
||||
# stack_info will give you the rest of the info you need
|
||||
maybe_user_loc = ""
|
||||
user_tb = TracingContext.extract_stack()
|
||||
if user_tb:
|
||||
maybe_user_loc = " at " + format_frame(user_tb[-1])
|
||||
|
||||
is_debug = self.log.isEnabledFor(logging.DEBUG)
|
||||
maybe_extra_debug = ""
|
||||
if is_debug and user_tb:
|
||||
maybe_extra_debug = (
|
||||
'\nUser Stack (most recent call last):\n' +
|
||||
' (snipped, see stack below for prefix)\n' +
|
||||
''.join(traceback.format_list(user_tb))
|
||||
)
|
||||
self.log.info(
|
||||
"eval %s [guard added]%s (%s)%s",
|
||||
g,
|
||||
maybe_user_loc,
|
||||
format_frame(frame),
|
||||
maybe_extra_debug,
|
||||
stack_info=is_debug,
|
||||
if self.frozen:
|
||||
self.counter["ignored_backward_guard"] += 1
|
||||
signpost_event(
|
||||
"dynamic",
|
||||
"evaluate_expr_frozen",
|
||||
{
|
||||
**self.co_fields,
|
||||
"ignored_guard": f"{expr} == {concrete_val}",
|
||||
# no version = original state (this signpost is expected)
|
||||
# version 2 = dynamic backwards is eagerly compiled
|
||||
"version": 2,
|
||||
},
|
||||
)
|
||||
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
|
||||
|
||||
if torch._dynamo.config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY and isinstance(hint, bool):
|
||||
if isinstance(expr, (sympy.Eq, sympy.Ne)):
|
||||
expr = sympy.Not(expr)
|
||||
|
||||
if isinstance(expr, (sympy.Eq, sympy.Ne)):
|
||||
self._maybe_guard_eq(expr, bool(concrete_val))
|
||||
# TODO: If we successfully eliminate a symbol via equality, it
|
||||
# is not actually necessary to save a guard for the equality,
|
||||
# as we will implicitly generate a guard when we match that
|
||||
# input against the symbol
|
||||
elif isinstance(concrete_val, sympy.Integer):
|
||||
# WARNING: we cannot actually do simplifications on guards
|
||||
# on floating point values, because Sympy generally does not
|
||||
# think expressions on integers can ever be equal to floating
|
||||
# point (e.g., sympy.Eq(s0/6, 0.5) evaluates to False). Without
|
||||
# very clear algebraic laws that hold for floating point, such
|
||||
# simplifications are error prone anyway, so be sure not to
|
||||
# maybe_guard_eq in those cases.
|
||||
self._maybe_guard_eq(sympy.Eq(expr, concrete_val), True)
|
||||
|
||||
if concrete_val is sympy.true:
|
||||
g = expr
|
||||
elif concrete_val is sympy.false:
|
||||
g = sympy.Not(expr)
|
||||
else:
|
||||
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
|
||||
|
||||
if not self._suppress_guards_tls():
|
||||
tb = traceback.extract_stack()[:-1]
|
||||
stack = ''.join(traceback.format_list(tb))
|
||||
guard = ShapeGuard(g, stack)
|
||||
self.guards.append(guard)
|
||||
except Exception:
|
||||
self.remove_fx_node(node)
|
||||
raise
|
||||
else:
|
||||
self.log.debug("eval %s [guard suppressed]", g)
|
||||
if not self._suppress_guards_tls():
|
||||
assert guard is not None
|
||||
assert tb is not None
|
||||
|
||||
self.refine_ranges(guard)
|
||||
|
||||
if self.log.isEnabledFor(logging.INFO):
|
||||
for frame in reversed(tb):
|
||||
if frame.filename not in uninteresting_files():
|
||||
break
|
||||
|
||||
# NB: this stack is truncated, but it's fine because the main
|
||||
# stack_info will give you the rest of the info you need
|
||||
maybe_user_loc = ""
|
||||
user_tb = TracingContext.extract_stack()
|
||||
if user_tb:
|
||||
maybe_user_loc = " at " + format_frame(user_tb[-1])
|
||||
|
||||
is_debug = self.log.isEnabledFor(logging.DEBUG)
|
||||
maybe_extra_debug = ""
|
||||
if is_debug and user_tb:
|
||||
maybe_extra_debug = (
|
||||
'\nUser Stack (most recent call last):\n' +
|
||||
' (snipped, see stack below for prefix)\n' +
|
||||
''.join(traceback.format_list(user_tb))
|
||||
)
|
||||
self.log.info(
|
||||
"eval %s [guard added]%s (%s)%s",
|
||||
g,
|
||||
maybe_user_loc,
|
||||
format_frame(frame),
|
||||
maybe_extra_debug,
|
||||
stack_info=is_debug,
|
||||
)
|
||||
else:
|
||||
self.log.debug("eval %s [guard suppressed]", g)
|
||||
|
||||
return concrete_val
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import operator
|
||||
import sympy
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Set, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.exc import TorchDynamoException
|
||||
@ -524,31 +524,6 @@ try:
|
||||
assert r == z3.unsat
|
||||
log.debug("translation validation: success")
|
||||
|
||||
|
||||
class ValidationException(TorchDynamoException):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
assertions: Iterable[z3.ExprRef],
|
||||
target_exprs: Iterable[z3.ExprRef],
|
||||
failed_source_exprs: Iterable[z3.ExprRef]
|
||||
) -> None:
|
||||
model_str = self._joinlines(model, lambda sym: f"{sym}: {model[sym]}")
|
||||
assertions_str = self._joinlines(assertions, z3str)
|
||||
target_exprs_str = self._joinlines(target_exprs, z3str)
|
||||
failed_source_exprs_str = self._joinlines(failed_source_exprs, z3str)
|
||||
|
||||
super().__init__(
|
||||
"translation validation failed.\n\n"
|
||||
"Model:\n" + model_str + "\n\n"
|
||||
"Assertions:\n" + assertions_str + "\n\n"
|
||||
"Target Expressions:\n" + target_exprs_str + "\n\n"
|
||||
"Failed Source Expressions:\n" + failed_source_exprs_str + "\n\n"
|
||||
)
|
||||
|
||||
def _joinlines(self, xs: Iterable[Any], f: Callable[[Any], str] = lambda x: x) -> str:
|
||||
return "\n".join(f" ==> {f(x)}" for x in xs)
|
||||
|
||||
except ImportError:
|
||||
_HAS_Z3 = False
|
||||
else:
|
||||
@ -573,5 +548,29 @@ def assert_z3_installed_if_tv_set():
|
||||
"z3-solver or disable translation validation."
|
||||
)
|
||||
|
||||
|
||||
class ValidationException(TorchDynamoException):
|
||||
def __init__(self, model, assertions, target_exprs, failed_source_exprs):
|
||||
assert _HAS_Z3
|
||||
|
||||
def symbolstr(sym) -> str:
|
||||
return f"{sym}: {model[sym]}"
|
||||
|
||||
def joinlines(xs) -> str:
|
||||
return "\n".join(f" ==> {x}" for x in xs)
|
||||
|
||||
model_str = joinlines(map(symbolstr, model))
|
||||
assertions_str = joinlines(map(z3str, assertions))
|
||||
target_exprs_str = joinlines(map(z3str, target_exprs))
|
||||
failed_source_exprs_str = joinlines(map(z3str, failed_source_exprs))
|
||||
|
||||
super().__init__(
|
||||
"translation validation failed.\n\n"
|
||||
"Model:\n" + model_str + "\n\n"
|
||||
"Assertions:\n" + assertions_str + "\n\n"
|
||||
"Target Expressions:\n" + target_exprs_str + "\n\n"
|
||||
"Failed Source Expressions:\n" + failed_source_exprs_str
|
||||
)
|
||||
|
||||
# Checks when this module is loaded.
|
||||
assert_z3_installed_if_tv_set()
|
||||
|
||||
Reference in New Issue
Block a user