Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[export] turn on hybrid symints by default (#130775)
Sets `prefer_deferred_runtime_asserts_over_guards=True` for export, so any guards emitted from `SymNode.expect_true` (for example, guards that are implicitly required to be true for an op to succeed) won't lead to constraint violations. Instead, these should appear in the graph as runtime asserts, or potentially as replacement expressions for placeholder shapes. For example, this reshape op should emit s0 * s1 = s2, deferred as a runtime assert:

```
x = torch.randn(4, 8)  # [s0, s1]
y = torch.randn(32)  # [s2]
out = x.reshape(-1) + y  # this emits Eq(s0 * s1, s2), and we represent y's shape as [s0*s1] in the graph.
```

However, other complex guards can still cause export to fail. For instance, guards emitted from `SymNode.guard_bool/guard_size_oblivious` (e.g. explicit if-else conditions in user code, or lower-level op implementations hit during tracing) can still raise constraint violations. These can be deferred with `allow_complex_guards_as_runtime_asserts=True`. We don't yet make this the default because, while it makes export more likely to succeed, it results in non-trivial asserts being emitted that often represent specialization to a variant of the op, or checks related to 0/1 specialization.

We also remove forced specializations for export and kill the `_disable_forced_specializations` flag: any guard we can't express with Dims/DerivedDims is now either handled with hybrid SymInts, or should be resolved by rewriting or deferring.

Follow-up: currently, `ShapeEnv._set_replacement()` is called for complex equality expressions (e.g. s2 -> s0*s1 in the example above), and the ExportedProgram stores `s0*s1` in the input placeholder. This isn't checked for validity when the program is run, so an option is to avoid the replacement and/or runtime assert on equality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130775
Approved by: https://github.com/avikchaudhuri
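A minimal sketch of the default behavior described above, assuming a build that includes this change; the `Reshaper` module, dimension names, and shapes are illustrative and not taken from the PR. Per the message, the reshape's equality requirement is expected to land as a runtime assert or placeholder replacement rather than an export-time constraint violation.

```python
import torch
from torch.export import Dim, export


class Reshaper(torch.nn.Module):
    def forward(self, x, y):
        # Requires x.numel() == y.shape[0]. With hybrid symints this is recorded
        # as a deferred runtime assert (and y's placeholder shape may become
        # [dx0*dx1]) instead of failing export with a constraint violation.
        return x.reshape(-1) + y


x, y = torch.randn(4, 8), torch.randn(32)
dynamic_shapes = {
    "x": (Dim("dx0", max=64), Dim("dx1", max=64)),
    "y": (Dim("dy", max=4096),),
}
ep = export(Reshaper(), (x, y), dynamic_shapes=dynamic_shapes)

# Inputs satisfying dx0*dx1 == dy run fine; mismatched inputs should trip the
# runtime assert carried in the exported graph.
ep.module()(torch.randn(2, 16), torch.randn(32))
```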
Committed by PyTorch MergeBot
Parent: 22388ffe03
Commit: 745324e487
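For guards that remain too complex under the default (for example, divisibility conditions from a reshape), the tests changed below call the private entry point `torch.export._trace._export` with `allow_complex_guards_as_runtime_asserts=True`. A hedged sketch of that call pattern follows; the `Divisible` module and its reshape are illustrative stand-ins, not code from this commit, and the private API may change.

```python
import torch
from torch.export import Dim


class Divisible(torch.nn.Module):
    def forward(self, x):
        # Succeeds only when x.numel() is divisible by 4; with the flag below,
        # the divisibility guard is deferred to a runtime assert instead of
        # forcing the dims to specialize to the example sizes.
        return x.reshape(-1, 4)


inputs = (torch.randn(10, 72),)
dx, dy = Dim("dx"), Dim("dy")

# Private API exercised by this PR's tests; subject to change.
ep = torch.export._trace._export(
    Divisible(),
    inputs,
    dynamic_shapes={"x": (dx, dy)},
    allow_complex_guards_as_runtime_asserts=True,
)
print(ep)
```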
@@ -2928,7 +2928,7 @@ def forward(self, x):
         dynamic_shapes = {"x": (dim0,)}
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
-            "must be specialized.*guards generated.*too complex",
+            r"Constraints violated \(dim0\)",
         ):
             torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes)
 
@@ -2936,7 +2936,7 @@ def forward(self, x):
 
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
-            "Not all values.*satisfy the generated guard",
+            r"Constraints violated \(dim0\)",
         ):
             torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes)
 
@@ -1905,17 +1905,19 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         dim0_x = torch.export.Dim("dim0_x", min=3)
         dim1_x = torch.export.Dim("dim1_x", max=8000)
         dynamic_shapes = {"x": (dim0_x, dim1_x)}
+        em = torch.export._trace._export(
+            m,
+            (a,),
+            dynamic_shapes=dynamic_shapes,
+            allow_complex_guards_as_runtime_asserts=True,
+        )
+        em.module()(torch.randn(4, 3))
         with self.assertRaisesRegex(
-            torch._dynamo.exc.UserError,
-            (
-                "Specializations unexpectedly required"
-                ".*\n.*\\[0\\] must be specialized to 3.*guards.*too complex(.*\n)*.*"
-                "Suggested fixes:(.*\n)*.*"
-                "dim0_x = 3(.*\n)*.*"
-                "dim1_x = 2\\*_dim1_x"
-            ),
+            RuntimeError,
+            r"Runtime assertion failed for expression Eq\(Mod\(s0\*s1, s0 \- 1\), 0\)",
         ):
-            torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
+            em.module()(torch.randn(4, 5))
 
         dim0_x = None
         dim1_x = 2 * torch.export.Dim("_dim1_x", max=4000)
         dynamic_shapes = {"x": (dim0_x, dim1_x)}
@@ -5773,8 +5775,8 @@ def forward(self, x, y):
             export(f, (inputs,), dynamic_shapes=dynamic_shapes)
 
     def test_disable_forced_specializations_ok(self):
-        # check that _disable_forced_specializations and allow_complex_guards_as_runtime_asserts flags
-        # both behave correctly, avoiding forced specializations and deferring to runtime.
+        # check that we don't force specialization, and defer to runtime asserts
+        # with allow_complex_guards_as_runtime_asserts=True to successfully export
         # case 1: modulo guards
         from torch.export import dims
 
@@ -5784,25 +5786,6 @@ def forward(self, x, y):
 
         inputs = (torch.randn(10, 72),)
         dx, dy = dims("dx", "dy")
-        with self.assertRaisesRegex(  # this will force specialize
-            torch._dynamo.exc.UserError,
-            r".*Specializations unexpectedly required(.*\n)*"
-            r".*dx = .* must be specialized to 10 because the guards generated for it are too complex(.*\n)*"
-            r".*dy = .* must be specialized to 72 because the guards generated for it are too complex(.*\n)*",
-        ):
-            export(
-                Mod4Reshape(),
-                inputs,
-                dynamic_shapes={"x": (dx, dy)},
-            )
-
-        torch.export._trace._export(  # just check this successfully compiles
-            Mod4Reshape(),
-            inputs,
-            dynamic_shapes={"x": (dx, dy)},
-            strict=False,
-            _disable_forced_specializations=True,
-        )
         ep = torch.export._trace._export(
             Mod4Reshape(),
             inputs,
@@ -5834,30 +5817,13 @@ def forward(self, x, y):
             "y": [Dim(f"dy{i}", min=2) for i in range(2)],
             "z": [Dim(f"dz{i}", min=4) for i in range(1)],
         }
-        with self.assertRaisesRegex(  # this will force specialize
-            torch._dynamo.exc.UserError,
-            r".*Specializations unexpectedly required(.*\n)*"
-            r".*dx0 = .* must be specialized to 6 because the guards generated for it are too complex(.*\n)*"
-            r".*dx1 = .* must be specialized to 8 because the guards generated for it are too complex(.*\n)*",
-        ):
-            export(
-                FreeReshape(),
-                inputs,
-                dynamic_shapes=dynamic_shapes,
-            )
-        torch.export._trace._export(
-            FreeReshape(),
-            inputs,
-            dynamic_shapes=dynamic_shapes,
-            strict=False,
-            _disable_forced_specializations=True,
-        )
         ep = torch.export._trace._export(
             FreeReshape(),
             inputs,
             dynamic_shapes=dynamic_shapes,
             allow_complex_guards_as_runtime_asserts=True,
         )
+        ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
         out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
         self.assertEqual(out1.shape, torch.ones(48).shape)
         out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
@@ -5881,28 +5847,6 @@ def forward(self, x, y):
             "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
             "y": (Dim("dy", min=8),),
         }
-        with self.assertRaisesRegex(  # this will force specialize
-            torch._dynamo.exc.UserError,
-            r".*Specializations unexpectedly required(.*\n)*"
-            r"Suggested fixes:(.*\n)*"
-            r".*dx0 = 4(.*\n)*"
-            r".*dx1 = 3(.*\n)*"
-            r".*dx2 = 2(.*\n)*"
-            r".*dy = 24(.*\n)*",
-        ):
-            export(
-                Reshape3d(),
-                inputs,
-                dynamic_shapes=dynamic_shapes,
-            )
-
-        torch.export._trace._export(
-            Reshape3d(),
-            inputs,
-            dynamic_shapes=dynamic_shapes,
-            strict=False,
-            _disable_forced_specializations=True,
-        )
         ep = torch.export._trace._export(
             Reshape3d(),
             inputs,
@@ -5918,7 +5862,7 @@ def forward(self, x, y):
             ep.module()(torch.randn(4, 3, 2), torch.randn(10))  # fail
 
     def test_disable_forced_specializations_errors(self):
-        # check error messages with disable_forced_specializations = False/True
+        # check error messages with hybrid symints
         class Foo(torch.nn.Module):
             def forward(self, w, x, y, z):
                 return w.reshape([-1]) + x, y + z  # simple: s0*s1 = s2, s3 = s4
 
@@ -5935,34 +5879,17 @@ def forward(self, x, y):
             "y": [Dim("dy")],  # y & z incorrect, export is supposed to fail.
             "z": [Dim("dz")],  # suggested fix should be to match these up.
         }
-        with self.assertRaisesRegex(  # if allow = False, suggested fixes should specialize 3, 4, 12.
-            torch._dynamo.exc.UserError,
-            r".*Specializations unexpectedly required(.*\n)*"
-            r"Suggested fixes:(.*\n)*"
-            r".*dw0 = 3(.*\n)*"
-            r".*dw1 = 4(.*\n)*"
-            r".*dx0 = 12(.*\n)*"
-            r".*dz = dy(.*\n)*",
-        ):
-            torch.export._trace._export(
-                Foo(),
-                inputs,
-                dynamic_shapes=dynamic_shapes,
-                strict=False,
-                _disable_forced_specializations=False,
-            )
         with self.assertRaisesRegex(  # if disable=True, suggested fixes should not specialize.
             torch._dynamo.exc.UserError,
             r".*Constraints violated(.*\n)*"
             r"Suggested fixes:(.*\n)*"
             r".*dz = dy(.*\n)*",
         ) as msg:
-            torch.export._trace._export(
+            export(
                 Foo(),
                 inputs,
                 dynamic_shapes=dynamic_shapes,
-                strict=False,
-                _disable_forced_specializations=True,
             )
 
         # TODO requires_grad doesn't seem to work with serialization.
@@ -6276,6 +6203,39 @@ def forward(self, x, y):
             ep.graph_module.code
         )
 
+    def test_slice_with_floordiv(self):
+        # slice operation emits runtime assert s0//2 <= s1
+        class M1(torch.nn.Module):
+            def forward(self, x, y):
+                d = x.size(0) // 2
+                return y[d:]
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.m1 = M1()
+
+            def forward(self, x, y):
+                d = x.size(0) // 2
+                m1_res = self.m1(x, y)
+                return y[d:] + m1_res
+
+        inputs = (torch.ones(10), torch.ones(10))
+        d0 = torch.export.Dim("d0", max=2048)
+        d1 = torch.export.Dim("d1", max=2048)
+        ep = export(
+            M(),
+            inputs,
+            dynamic_shapes=((d0,), (d1,)),
+        )
+        ep.module()(torch.ones(8), torch.ones(4))
+        ep.module()(torch.ones(8), torch.ones(5))
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"Runtime assertion failed for expression \(s0//2\) \<\= s1",
+        ):
+            ep.module()(torch.ones(10), torch.ones(4))
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
@@ -2,7 +2,7 @@
 import contextlib
 import inspect
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union
 
 import torch
 import torch.utils._pytree as pytree

@@ -166,7 +166,7 @@ def make_fake_inputs(
             shape_env=ShapeEnv(
                 tracked_fakes=[],
                 co_fields=co_fields,
-                prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+                prefer_deferred_runtime_asserts_over_guards=True,
                 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
             ),
             allow_non_fake_inputs=True,

@@ -176,7 +176,7 @@ def make_fake_inputs(
         fake_mode = FakeTensorMode(
             shape_env=ShapeEnv(
                 tracked_fakes=[],
-                prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+                prefer_deferred_runtime_asserts_over_guards=True,
                 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
             ),
             allow_non_fake_inputs=True,

@@ -242,7 +242,6 @@ def produce_guards_and_solve_constraints(
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: EqualityConstraint,
     original_signature: inspect.Signature,
-    _disable_forced_specializations: Optional[bool] = False,
     _is_torch_jit_trace=False,
 ):
     """

@@ -254,7 +253,6 @@ def produce_guards_and_solve_constraints(
     Additional inputs:
     equalities_inputs: the equality constraints to use for guards
     original_signature: the signature of the forward method
-    _disable_forced_specializations: if True, avoids forced specializations
     """
     shape_env = fake_mode.shape_env
     assert shape_env is not None

@@ -271,7 +269,6 @@ def produce_guards_and_solve_constraints(
             input_contexts=input_contexts,
             equalities_inputs=equalities_inputs,
             ignore_static=False,
-            _disable_forced_specializations=_disable_forced_specializations,
         )
     except ConstraintViolationError as e:
         constraint_violation_error = e

@@ -284,9 +281,7 @@ def produce_guards_and_solve_constraints(
         # TODO(avik): Maybe record the constraint violation error instead and replay later?
         assert constraint_violation_error
         raise constraint_violation_error
-    dim_constraints.solve(
-        _disable_forced_specializations=_disable_forced_specializations
-    )
+    dim_constraints.solve()
     dim_constraints.remove_redundant_dynamic_results()
     forced_specializations = dim_constraints.forced_specializations()
     if not _is_torch_jit_trace:
@@ -172,12 +172,15 @@ _SYM_INT_OPS = {
     operator.sub,
     operator.floordiv,
     operator.mod,
     operator.pow,
     torch.sym_int,
     torch.sym_float,
     torch.sym_ite,
     torch.sym_max,
     torch.sym_min,
     torch.sym_sqrt,
     torch.ops.aten.sym_size.int,
     torch.ops.aten.sym_stride.int,
 }
 
 
@@ -215,11 +218,11 @@ def deserialize_device(d: Device) -> torch.device:
 
 
 def serialize_sym_int(s: Union[int, torch.SymInt]) -> SymInt:
-    if isinstance(s, (torch.SymInt, int)):
+    if isinstance(s, (torch.SymInt, sympy.Symbol, int)):
         if symbolic_shapes.is_concrete_int(s):
             return SymInt.create(as_int=int(s))
         else:
-            assert isinstance(s, torch.SymInt)
+            assert isinstance(s, (torch.SymInt, sympy.Symbol))
             if s.node.hint is None:
                 return SymInt.create(as_expr=SymExpr(str(s)))
             else:

@@ -487,9 +490,13 @@ class GraphModuleSerializer(metaclass=Final):
         if node.target is operator.getitem:
             return
 
-        if node.target in _SYM_INT_OPS:
+        meta_val = node.meta.get("val")
+        if (
+            node.target in _SYM_INT_OPS
+            or node.target in _SYM_BOOL_OPS
+            or (meta_val is not None and isinstance(meta_val, (torch.SymInt, torch.SymBool)))
+        ):
             assert len(node.kwargs) == 0
-            meta_val = node.meta["val"]
             ex_node = Node(
                 target=self.serialize_operator(node.target),
                 inputs=self.serialize_sym_op_inputs(node.target, node.args),

@@ -497,17 +504,8 @@ class GraphModuleSerializer(metaclass=Final):
                     Argument.create(
                         as_sym_int=self.serialize_sym_int_output(node.name, meta_val)
                     )
-                ],
-                metadata=self.serialize_metadata(node),
-            )
-        elif node.target in _SYM_BOOL_OPS:
-            assert len(node.kwargs) == 0
-            meta_val = node.meta["val"]
-            ex_node = Node(
-                target=self.serialize_operator(node.target),
-                inputs=self.serialize_sym_op_inputs(node.target, node.args),
-                outputs=[
-                    Argument.create(
+                    if (node.target in _SYM_INT_OPS or isinstance(meta_val, torch.SymInt))
+                    else Argument.create(
                         as_sym_bool=self.serialize_sym_bool_output(node.name, meta_val)
                     )
                 ],

@@ -1538,6 +1536,15 @@ class GraphModuleDeserializer(metaclass=Final):
     def deserialize_sym_bool(self, s: SymBool) -> Union[bool, torch.SymBool]:
         val = s.value
         if s.type == "as_expr":
+            # first we sympify this just to access any untracked symbols
+            expr = sympy.sympify(val.expr_str)
+            for sym in expr.free_symbols:
+                if (
+                    not isinstance(sym, sympy.Number)
+                    and str(sym) not in self.symbol_name_to_symbol
+                ):
+                    self.deserialize_sym_int(SymInt.create(as_expr=SymExpr(str(sym))))
+            # then we sympify again using locals to correctly reify with the constructed symbols
             expr = sympy.sympify(val.expr_str, locals=self.symbol_name_to_symbol)
             return self.shape_env.create_symboolnode(expr)
         elif s.type == "as_bool":

@@ -1661,7 +1668,11 @@ class GraphModuleDeserializer(metaclass=Final):
         return self.graph
 
     def deserialize_node(self, serialized_node: Node, target: Callable) -> None:
-        if target in _SYM_BOOL_OPS or target in _SYM_INT_OPS:
+        if (
+            target in _SYM_BOOL_OPS
+            or target in _SYM_INT_OPS
+            or target == torch.ops.aten.item.default  # this can produce either SymInt or SymBool
+        ):
             name = serialized_node.outputs[0].value.as_name
             args = self.deserialize_sym_op_inputs(serialized_node.inputs)
@@ -148,9 +148,9 @@ def _check_input_constraints_for_graph(
                 )
             else:
                 if arg_dim != node_dim:
-                    if isinstance(
-                        node_dim, torch.SymInt
-                    ):  # this means we deferred a guard from export analysis to runtime, let this pass
+                    if isinstance(node_dim, torch.SymInt):
+                        # this means we deferred a guard from export analysis to runtime, let this pass
+                        # we'll add a runtime assert checking equality to this replacement expression
                         continue
                     raise RuntimeError(
                         f"Expected input at {get_keystr(key_path)}.shape[{j}] to be equal to "
@@ -96,6 +96,11 @@ class ExportDynamoConfig:
     reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )
+    # Emit runtime asserts after AOTAutograd instead.
+    # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE,
+    # but if we want to reason more about what guards/runtime asserts to emit,
+    # this makes it a bit cleaner to do from the export side. Also no real point in running this twice.
+    do_not_emit_runtime_asserts = True
 
 
 @dataclasses.dataclass

@@ -549,7 +554,7 @@ def _export_to_torch_ir(
             disable_constraint_solver=disable_constraint_solver,
             # currently the following 2 flags are tied together for export purposes,
             # but untangle for sake of dynamo export api
-            prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
+            prefer_deferred_runtime_asserts_over_guards=True,
             allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
             _log_export_usage=_log_export_usage,
             same_signature=same_signature,

@@ -580,6 +585,7 @@ def _export_to_aten_ir(
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
+    produce_guards_callback=None,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,

@@ -625,16 +631,27 @@ def _export_to_aten_ir(
     if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
         gm.meta.update(mod.meta)
 
-    # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
-    # Overwrite output specs afterwards.
-    from torch._dynamo import config as _dynamo_config
     from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names
     from torch._guards import detect_fake_mode
 
+    # Run produce guards before we handle runtime asserts.
+    # This means we run the export solver before the runtime asserts pass.
+    # Right now this doesn't mean much - the export solver is only there for suggested fixes,
+    # and we won't even get to constraint solving if that's needed.
+    # But if in future we want to control what runtime asserts are emitted for export,
+    # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense.
+    if produce_guards_callback:
+        try:
+            produce_guards_callback(gm)
+        except (ConstraintViolationError, ValueRangeError) as e:
+            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
+
+    # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
+    # Overwrite output specs afterwards.
     flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs))
     fake_mode = detect_fake_mode(flat_fake_args)
 
-    if not _dynamo_config.do_not_emit_runtime_asserts:
+    if not torch._dynamo.config.do_not_emit_runtime_asserts:
         stack_trace = (
             'File "torch/fx/passes/runtime_assert.py", line 24, '
             "in insert_deferred_runtime_asserts"
@@ -1100,7 +1117,6 @@ def _strict_export(
    original_state_dict: Dict[str, Any],
    orig_in_spec: TreeSpec,
    allow_complex_guards_as_runtime_asserts: bool,
-    _disable_forced_specializations: Optional[bool],
    _is_torch_jit_trace: bool,
 ) -> ExportArtifact:
     lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch)

@@ -1114,7 +1130,6 @@ def _strict_export(
         original_state_dict=original_state_dict,
         orig_in_spec=orig_in_spec,
         allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
-        _disable_forced_specializations=_disable_forced_specializations,
         _is_torch_jit_trace=_is_torch_jit_trace,
         lower_to_aten_callback=lower_to_aten,
     )

@@ -1130,7 +1145,6 @@ def _strict_export_lower_to_aten_ir(
    original_state_dict: Dict[str, Any],
    orig_in_spec: TreeSpec,
    allow_complex_guards_as_runtime_asserts: bool,
-    _disable_forced_specializations: Optional[bool],
    _is_torch_jit_trace: bool,
    lower_to_aten_callback: Callable,
 ) -> ExportArtifact:

@@ -1303,6 +1317,7 @@ def _export_to_aten_ir_make_fx(
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
+    produce_guards_callback=None,
    transform=lambda x: x,
 ) -> ATenExportArtifact:
     @contextmanager

@@ -1469,13 +1484,18 @@ def _export_to_aten_ir_make_fx(
         input_specs=input_specs, output_specs=output_specs
     )
 
+    # See comment in _export_to_aten_ir()
+    if produce_guards_callback:
+        try:
+            produce_guards_callback(gm)
+        except (ConstraintViolationError, ValueRangeError) as e:
+            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
+
     from torch._guards import detect_fake_mode
 
     fake_mode = detect_fake_mode(flat_args)
 
-    from torch._dynamo import config as _dynamo_config
-
-    if not _dynamo_config.do_not_emit_runtime_asserts:
+    if not torch._dynamo.config.do_not_emit_runtime_asserts:
         stack_trace = (
             'File "torch/fx/passes/runtime_assert.py", line 24, '
             "in insert_deferred_runtime_asserts"
@@ -1534,7 +1554,6 @@ def _non_strict_export(
    original_state_dict: Dict[str, Any],
    orig_in_spec: TreeSpec,
    allow_complex_guards_as_runtime_asserts: bool,
-    _disable_forced_specializations: Optional[bool],
    _is_torch_jit_trace: bool,
    _is_training: bool = False,
 ) -> ExportArtifact:

@@ -1625,6 +1644,16 @@ def _non_strict_export(
 
     fake_params_buffers = make_fake_params_buffers(fake_mode, _get_params_buffers(mod))
 
+    def _produce_guards_callback(gm):
+        return produce_guards_and_solve_constraints(
+            fake_mode=fake_mode,
+            gm=gm,
+            dynamic_shapes=dynamic_shapes,
+            equalities_inputs=equalities_inputs,
+            original_signature=original_signature,
+            _is_torch_jit_trace=_is_torch_jit_trace,
+        )
+
     with fake_mode:
         with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as (
             patched_mod,

@@ -1648,6 +1677,7 @@ def _non_strict_export(
                 new_fake_kwargs,
                 fake_params_buffers,
                 new_fake_constant_attrs,
+                produce_guards_callback=_produce_guards_callback,
                 transform=_tuplify_outputs,
             )
         # aten_export_artifact.constants contains only fake script objects, we need to map them back

@@ -1656,19 +1686,6 @@ def _non_strict_export(
         for fqn, obj in aten_export_artifact.constants.items()
     }
 
-    try:
-        produce_guards_and_solve_constraints(
-            fake_mode,
-            aten_export_artifact.gm,
-            dynamic_shapes,
-            equalities_inputs,
-            original_signature,
-            _disable_forced_specializations=_disable_forced_specializations,
-            _is_torch_jit_trace=_is_torch_jit_trace,
-        )
-    except (ConstraintViolationError, ValueRangeError) as e:
-        raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: B904
-
     _rewrite_non_persistent_buffers(
         mod, aten_export_artifact.sig, aten_export_artifact.constants
     )
@@ -1733,7 +1750,6 @@ def _export_for_training(
         original_state_dict=original_state_dict,
         orig_in_spec=orig_in_spec,
         allow_complex_guards_as_runtime_asserts=False,
-        _disable_forced_specializations=False,
         _is_torch_jit_trace=False,
     )
 

@@ -1821,7 +1837,6 @@ def _export(
    preserve_module_call_signature: Tuple[str, ...] = (),
    pre_dispatch: bool = False,
    allow_complex_guards_as_runtime_asserts: bool = False,
-    _disable_forced_specializations: Optional[bool] = False,
    _is_torch_jit_trace: bool = False,
 ) -> ExportedProgram:
     """

@@ -1864,13 +1879,6 @@ def _export(
     Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints
     while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes.
 
-    _disable_forced_specializations:
-        Similar to allow_complex_guards_as_runtime_asserts, but only avoids specializing to static values if set to True.
-        For complex guards that don't specialize, this flag doesn't have any effect. Ideally this would be subsumed by
-        allow_complex_guards_as_runtime_asserts, but this handles one additional case: single-variable equalities where
-        the symbol is solvable for a concrete value (e.g. Eq(s0 // 4, 400) -> s0 = 1600). If set to True, this flag will
-        avoid specializations. Direct equalities (e.g. s0 = 4), will still specialize.
-
     Returns:
         An ExportedProgram containing the traced method.
     """

@@ -1880,12 +1888,6 @@ def _export(
             f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
         )
 
-    if _disable_forced_specializations and strict:
-        raise UserError(
-            UserErrorType.INVALID_INPUT,
-            "_disable_forced_specializations can be only be specified in non-strict mode.",
-        )
-
     global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY
     _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
 

@@ -1919,7 +1921,6 @@ def _export(
         original_state_dict,
         orig_in_spec,
         allow_complex_guards_as_runtime_asserts,
-        _disable_forced_specializations,
         _is_torch_jit_trace,
     )
     # Decompose here for readability.
@@ -502,14 +502,13 @@ def _decompose_and_get_gm_with_new_signature_constants(
 
     # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature.
     # Overwrite output specs afterwards.
-    from torch._dynamo import config as _dynamo_config
     from torch._export.passes._node_metadata_hook import (
         _node_metadata_hook,
         _set_node_metadata_hook,
     )
     from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names
 
-    if not _dynamo_config.do_not_emit_runtime_asserts:
+    if not torch._dynamo.config.do_not_emit_runtime_asserts:
         stack_trace = (
             'File "torch/fx/passes/runtime_assert.py", line 24, '
             "in insert_deferred_runtime_asserts"
@@ -1772,38 +1772,7 @@ class DimConstraints:
             self._inconsistencies.clear()
             raise ValueError(f"The following inconsistencies were found:\n{msg}")
 
-    def _force_specialization(self, s):
-        val = self._var_to_val[s]
-        self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
-        self._substitutions[s] = val
-
-    def _specialize_divisor_symbols(self):
-        for expr in self._multivariate_inequalities:
-            for atom in expr.atoms(FloorDiv, Mod):
-                _, divisor = atom.args
-                for s in divisor.free_symbols:
-                    self._force_specialization(s)
-
-        multivariate_inequalities = self._multivariate_inequalities
-        self._multivariate_inequalities = set()
-        for expr in multivariate_inequalities:
-            self.add(expr.xreplace(self._substitutions))
-        self._raise_inconsistencies()
-        self._univariate_inequalities = {
-            s: exprs
-            for s, exprs in self._univariate_inequalities.items()
-            if s not in self._substitutions
-        }
-        self._congruences = {
-            s: congruences
-            for s, congruences in self._congruences.items()
-            if s not in self._substitutions
-        }
-
-    def solve(
-        self,
-        _disable_forced_specializations=False,
-    ):
+    def solve(self):
         """Solve the system of constraint equations to find simplified constraints
         """
         self._raise_inconsistencies()

@@ -1818,12 +1787,10 @@ class DimConstraints:
             assert isinstance(solution, sympy.Eq), f"Expected an equality constraint for {s}, got {solution}"
             symbol, val = solution.args
             assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}"
-            # really don't force specializations here
-            if not (_disable_forced_specializations and s in self._marked_dynamic):
-                # because this is univariate, the solution is a specialization
-                self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
-                # add this as a substitution to simplify other constraints
-                self._substitutions[s] = val
+            # because this is univariate, the solution is a specialization
+            self._static_results.add(f"{self._dcp.symbol_to_source[s][0].name()} == {val}")
+            # add this as a substitution to simplify other constraints
+            self._substitutions[s] = val
 
         # simplify multivariate inequalities: some of them will now become univariate!
         multivariate_inequalities = self._multivariate_inequalities

@@ -1832,9 +1799,6 @@ class DimConstraints:
             self.add(expr.xreplace({s: self._substitutions[s]}))
         self._raise_inconsistencies()
 
-        if not _disable_forced_specializations:
-            self._specialize_divisor_symbols()
-
         # solve linear congruences
         # NOTE(avik): We do not need to solve them for symbols that have already been specialized.
         reduced_congruences = self._reduce_congruences()

@@ -1850,9 +1814,6 @@ class DimConstraints:
                     self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)]
                     r = try_solve(sympy.Eq(base, divisor * tmp), s)
                     self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1])))
-                elif not _disable_forced_specializations:
-                    self._force_specialization(s)
-                    self._univariate_inequalities.pop(s, None)
 
         # remaining symbols have only pure inequalities (no equalities)
         for s, exprs in self._univariate_inequalities.items():

@@ -1875,11 +1836,6 @@ class DimConstraints:
         symbolic_equivalences = self._symbolic_equivalences
         self._symbolic_equivalences = []
         for source, expr in symbolic_equivalences:
-            if not _disable_forced_specializations and not _is_supported_equivalence(expr):
-                for s in expr.free_symbols:
-                    self._force_specialization(s)
-                    sexpr = self._dcp._print_Symbol(s)
-                    self._dynamic_results = {r for r in self._dynamic_results if sexpr not in r}
             self.add_equality(source, expr.xreplace(self._substitutions))
 
         # remaining symbolic equivalences become dynamic equality constraints
@@ -2893,6 +2849,7 @@ class ShapeEnv:
         we know statically is already True but we are checking it again in a way
         that is not clearly dischargeable.
         """
+        # self.prefer_deferred_runtime_asserts_over_guards = False
         self.runtime_asserts_frozen = True
 
     def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]:

@@ -3656,7 +3613,6 @@ class ShapeEnv:
         # (See docs on EqualityConstraint for details of the encoding.)
         equalities_inputs: Optional[EqualityConstraint] = None,
         _simplified=False,
-        _disable_forced_specializations=False,
         # Indicates if we should produce guards for known static values.
         ignore_static=True,
     ) -> List[str]:

@@ -4096,13 +4052,12 @@ class ShapeEnv:
                 constraints = symbol_to_constraints[symbol]
                 for c in constraints:
                     if isinstance(c, StrictMinMaxConstraint):
-                        if not _disable_forced_specializations:
-                            var_with_range = self._render_range_for_constraint_violation(source, c)
-                            msg = (
-                                f"Not all values of {var_with_range} "
-                                f"satisfy the generated guard {guard_expr}."
-                            )
-                            record_constraint_violation(c.warn_only, self._debug_name(source), msg)
+                        var_with_range = self._render_range_for_constraint_violation(source, c)
+                        msg = (
+                            f"Not all values of {var_with_range} "
+                            f"satisfy the generated guard {guard_expr}."
+                        )
+                        record_constraint_violation(c.warn_only, self._debug_name(source), msg)
                     elif isinstance(c, RelaxedUnspecConstraint):
                         # This is fine, we allow guards here as long as it
                         # didn't constrain it to one value (we don't

@@ -4123,6 +4078,17 @@ class ShapeEnv:
                         continue
                     issue_guard(guard)
 
+            # Because there are guards that export's constraint solver can suggest good fixes for, that we may have
+            # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards),
+            # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts,
+            # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide
+            # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph).
+            for ra in self.deferred_runtime_asserts.get(None, []):
+                if self._maybe_evaluate_static(ra.expr, axioms=()) is not None:
+                    continue
+                expr = self.simplify(ra.expr)
+                self.dim_constraints.add(expr)
+
         # 3. Every symbol must be within its value range (this handles 0/1
         # specialization too).
         for symbol, sources in symbol_to_source.items():