mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
switch prefer_deferred_runtime_asserts_over_guards in export (#160111)
Summary: In preparation for checking shape guards in export, this PR effectively switches `prefer_deferred_runtime_asserts_over_guards` to `False`, matching Dynamo. Actually that's a lie: we switch it to `allow_complex_guards_as_runtime_asserts`, which is `False` by default but can be controlled via an internally API to be `True`. This makes the two flags synchronized, so we should be able to kill `allow_complex_guards_as_runtime_asserts` at this point. Test Plan: updated tests Rollback Plan: Differential Revision: D79734206 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160111 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
6b051d7de3
commit
12c0cf3fab
@ -318,10 +318,7 @@ def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
linear_weight = self.linear.weight
|
||||
linear_bias = self.linear.bias
|
||||
sym_size_int_2 = torch.ops.aten.sym_size.int(x, 1)
|
||||
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
|
||||
eq = sym_size_int_2 == 4; sym_size_int_2 = None
|
||||
_assert_scalar_default = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s27, 4) on node 'eq'"); eq = _assert_scalar_default = None
|
||||
return pytree.tree_unflatten((linear,), self._out_spec)""",
|
||||
)
|
||||
|
||||
|
@ -11381,7 +11381,6 @@ graph():
|
||||
ep.module()(4, torch.randn(4, 4))
|
||||
|
||||
@testing.expectedFailureCppRuntime
|
||||
@testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3
|
||||
def test_symint_input_ranges(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -11415,8 +11414,12 @@ graph():
|
||||
)
|
||||
constraints = list(ep.range_constraints.values())
|
||||
constraint = constraints[0]
|
||||
self.assertEqual(constraint.lower, 4)
|
||||
self.assertEqual(constraint.upper, 5)
|
||||
# retracebility does not remember the range asserts in the forward
|
||||
lower, upper = (
|
||||
(3, 10) if is_retracebility_test(self._testMethodName) else (4, 5)
|
||||
)
|
||||
self.assertEqual(constraint.lower, lower)
|
||||
self.assertEqual(constraint.upper, upper)
|
||||
|
||||
# While tracing the range was found to be bigger than the original range
|
||||
class M(torch.nn.Module):
|
||||
@ -13440,22 +13443,36 @@ def forward(self, x, y):
|
||||
"y": [Dim(f"dy{i}", min=2) for i in range(2)],
|
||||
"z": [Dim(f"dz{i}", min=4) for i in range(1)],
|
||||
}
|
||||
ep = torch.export._trace._export(
|
||||
FreeReshape(),
|
||||
inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
allow_complex_guards_as_runtime_asserts=True,
|
||||
)
|
||||
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
|
||||
self.assertEqual(out1.shape, torch.ones(48).shape)
|
||||
out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
|
||||
self.assertEqual(out2.shape, torch.ones(40).shape)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
|
||||
): # fail only at runtime
|
||||
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)) # fail
|
||||
|
||||
for private_api in (True, False):
|
||||
if private_api:
|
||||
ep = torch.export._trace._export(
|
||||
FreeReshape(),
|
||||
inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
allow_complex_guards_as_runtime_asserts=True,
|
||||
)
|
||||
else:
|
||||
ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
|
||||
out1 = ep.module()(torch.randn(48, 1), torch.randn(4, 12), torch.randn(48))
|
||||
self.assertEqual(out1.shape, torch.ones(48).shape)
|
||||
out2 = ep.module()(torch.randn(5, 8), torch.randn(4, 10), torch.randn(40))
|
||||
self.assertEqual(out2.shape, torch.ones(40).shape)
|
||||
if private_api:
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\((.*)\) on node '.*'",
|
||||
): # fail only at runtime
|
||||
ep.module()(
|
||||
torch.randn(5, 8), torch.randn(4, 5), torch.randn(30)
|
||||
) # fail
|
||||
else:
|
||||
# no runtime assert in exported module but it fails anyway with wrong inputs
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"The size of tensor a \(40\) must match the size of tensor b \(20\) at non-singleton dimension 0",
|
||||
):
|
||||
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30))
|
||||
|
||||
# case 3: 3d reshape (previously failing with different issue)
|
||||
class Reshape3d(torch.nn.Module):
|
||||
@ -14925,21 +14942,41 @@ def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.view(x.shape[0] - 1, -1)
|
||||
|
||||
ep = export(
|
||||
ModConstraint(),
|
||||
(torch.randn(3, 4),),
|
||||
dynamic_shapes={
|
||||
"x": (dynamic, dynamic),
|
||||
},
|
||||
)
|
||||
ep.module()(torch.randn(5, 8))
|
||||
num_asserts = [
|
||||
node.target == torch.ops.aten._assert_scalar.default
|
||||
for node in ep.graph.nodes
|
||||
].count(True)
|
||||
self.assertEqual(num_asserts, 2)
|
||||
with self.assertRaises(RuntimeError):
|
||||
ep.module()(torch.randn(4, 2))
|
||||
for private_api in (True, False):
|
||||
if private_api:
|
||||
ep = torch.export._trace._export(
|
||||
ModConstraint(),
|
||||
(torch.randn(3, 4),),
|
||||
dynamic_shapes={"x": (dynamic, dynamic)},
|
||||
allow_complex_guards_as_runtime_asserts=True,
|
||||
)
|
||||
else:
|
||||
ep = export(
|
||||
ModConstraint(),
|
||||
(torch.randn(3, 4),),
|
||||
dynamic_shapes={"x": (dynamic, dynamic)},
|
||||
)
|
||||
ep.module()(torch.randn(5, 8))
|
||||
num_asserts = [
|
||||
node.target == torch.ops.aten._assert_scalar.default
|
||||
for node in ep.graph.nodes
|
||||
].count(True)
|
||||
if private_api:
|
||||
self.assertEqual(num_asserts, 7)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
|
||||
):
|
||||
ep.module()(torch.randn(4, 2))
|
||||
else:
|
||||
# no runtime assert in exported module
|
||||
self.assertEqual(num_asserts, 0)
|
||||
# but it fails anyway with wrong inputs
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"shape '\[3, -1\]' is invalid for input of size 8",
|
||||
):
|
||||
ep.module()(torch.randn(4, 2))
|
||||
|
||||
@testing.expectedFailureSerDer # T195866111
|
||||
@testing.expectedFailureSerDerNonStrict
|
||||
|
@ -933,7 +933,7 @@ def forward(self, x):
|
||||
fn_count_sym_size = lambda graph: [node.target for node in graph.nodes].count(
|
||||
torch.ops.aten.sym_size.int
|
||||
)
|
||||
self.assertEqual(fn_count_sym_size(unflat.graph), 3)
|
||||
self.assertEqual(fn_count_sym_size(unflat.graph), 1)
|
||||
self.assertEqual(fn_count_sym_size(unflat.m1.graph), 1)
|
||||
self.assertEqual(fn_count_sym_size(unflat.m2.graph), 0)
|
||||
|
||||
|
@ -159,7 +159,11 @@ class AOTIRunnerUtil:
|
||||
with torch.no_grad():
|
||||
# strict=False needs extra migration work
|
||||
ep = torch.export.export(
|
||||
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
|
||||
model,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=True,
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
)
|
||||
package_path = torch._inductor.aoti_compile_and_package(
|
||||
ep, inductor_configs=inductor_configs
|
||||
|
@ -382,7 +382,7 @@ def make_fake_inputs(
|
||||
shape_env=ShapeEnv(
|
||||
tracked_fakes=[],
|
||||
co_fields=co_fields,
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
|
||||
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
|
||||
trace_asserts=True,
|
||||
),
|
||||
|
@ -69,6 +69,7 @@ def export_for_training(
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -157,6 +158,7 @@ def export_for_training(
|
||||
dynamic_shapes,
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
|
||||
|
||||
@ -168,6 +170,7 @@ def export(
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
strict: bool = False,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
:func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing
|
||||
@ -279,6 +282,7 @@ def export(
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
pre_dispatch=True,
|
||||
allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
except Exception as e:
|
||||
draft_export_msg = (
|
||||
|
@ -812,7 +812,7 @@ def _export_to_torch_ir(
|
||||
disable_constraint_solver=disable_constraint_solver,
|
||||
# currently the following 2 flags are tied together for export purposes,
|
||||
# but untangle for sake of dynamo export api
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
|
||||
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
|
||||
_log_export_usage=_log_export_usage,
|
||||
same_signature=same_signature,
|
||||
@ -2037,6 +2037,7 @@ def _export_for_training(
|
||||
*,
|
||||
strict: bool = True,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
allow_complex_guards_as_runtime_asserts: bool = False,
|
||||
) -> ExportedProgram:
|
||||
global _EXPORT_MODULE_HIERARCHY
|
||||
_EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
|
||||
@ -2061,7 +2062,7 @@ def _export_for_training(
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
orig_in_spec=orig_in_spec,
|
||||
allow_complex_guards_as_runtime_asserts=False,
|
||||
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
|
||||
_to_aten_func=_export_to_aten_ir_make_fx,
|
||||
)
|
||||
|
||||
@ -2198,6 +2199,7 @@ def _export(
|
||||
dynamic_shapes,
|
||||
strict=strict,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
|
||||
)
|
||||
dtrace_structured("exported_program", payload_fn=lambda: str(ep))
|
||||
return ep
|
||||
|
Reference in New Issue
Block a user