kill allow_complex_guards_as_runtime_asserts (#160198)
Summary: Since `allow_complex_guards_as_runtime_asserts` is now sync'd with `prefer_deferred_runtime_asserts_over_guards`, we can kill the former (especially since it was an export-only concept).

Test Plan: updated tests

Differential Revision: D79903317

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160198
Approved by: https://github.com/ezyang
Commit: 196232bb93
Parent: fa76256603
Committed by: PyTorch MergeBot
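For readers migrating call sites, a minimal before/after sketch of what this change means in practice (the module `M`, input `x`, and `dynamic_shapes` below are illustrative placeholders, not taken from the diff):

# Before: export-only flag, only reachable through the private entry point
em = torch.export._trace._export(
    M(),
    (x,),
    dynamic_shapes=dynamic_shapes,
    allow_complex_guards_as_runtime_asserts=True,
)

# After: the single surviving knob, accepted by the public API
em = torch.export.export(
    M(),
    (x,),
    dynamic_shapes=dynamic_shapes,
    prefer_deferred_runtime_asserts_over_guards=True,
)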
@@ -10916,8 +10916,8 @@ def ___make_guard_fn():
 ShapeEnv not equal: field values don't match:

 ==> settings: values don't match.

-> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
-> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, allow_complex_guards_as_runtime_asserts=False, trace_asserts=False)
+> Left: ShapeEnvSettings(allow_scalar_outputs=False, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
+> Right: ShapeEnvSettings(allow_scalar_outputs=True, allow_dynamic_output_shape_ops=True, assume_static_by_default=False, specialize_zero_one=True, duck_shape=True, prefer_deferred_runtime_asserts_over_guards=False, trace_asserts=False)
 """,
 )
 self._replay_and_check(main)
@@ -5514,11 +5514,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
 dim0_x = torch.export.Dim("dim0_x", min=3)
 dim1_x = torch.export.Dim("dim1_x", max=8000)
 dynamic_shapes = {"x": (dim0_x, dim1_x)}
-em = torch.export._trace._export(
+em = torch.export.export(
 m,
 (a,),
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 em.module()(torch.randn(4, 3))
 with self.assertRaisesRegex(
@@ -13402,7 +13402,7 @@ def forward(self, x, y):

 def test_disable_forced_specializations_ok(self):
 # check that we don't force specialization, and defer to runtime asserts
-# with allow_complex_guards_as_runtime_asserts=True to successfully export
+# with prefer_deferred_runtime_asserts_over_guards=True to successfully export
 # case 1: modulo guards
 from torch.export import dims

@@ -13412,11 +13412,11 @@ def forward(self, x, y):

 inputs = (torch.randn(10, 72),)
 dx, dy = dims("dx", "dy")
-ep = torch.export._trace._export(
+ep = torch.export.export(
 Mod4Reshape(),
 inputs,
 dynamic_shapes={"x": (dx, dy)},
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 out1 = ep.module()(torch.randn(8, 7))
 self.assertEqual(out1.shape, torch.ones(7, 4, 2).shape)
@@ -13446,11 +13446,11 @@ def forward(self, x, y):

 for private_api in (True, False):
 if private_api:
-ep = torch.export._trace._export(
+ep = torch.export.export(
 FreeReshape(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 else:
 ep = export(FreeReshape(), inputs, dynamic_shapes=dynamic_shapes)
@@ -13487,11 +13487,11 @@ def forward(self, x, y):
 "x": (Dim("dx0", min=2), Dim("dx1", min=2), Dim("dx2", min=2)),
 "y": (Dim("dy", min=8),),
 }
-ep = torch.export._trace._export(
+ep = torch.export.export(
 Reshape3d(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 out1 = ep.module()(torch.randn(9, 7, 2), torch.randn(126))
 self.assertEqual(out1.shape, torch.ones(126).shape)
@@ -13613,11 +13613,11 @@ def forward(self, x, y):
 model = Model()
 x = torch.rand(1024, 20, 16)
 dynamic_shapes = {"x": {0: Dim("batch")}}
-ep = torch.export._trace._export(
+ep = torch.export.export(
 model,
 (x,),
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 with self.assertRaisesRegex(
 RuntimeError,
@@ -13690,11 +13690,11 @@ def forward(self, x, y):

 inputs = (torch.randn(6), torch.randn(12))
 dynamic_shapes = {"x": [Dim("dx", min=4)], "y": [Dim("dy", min=4)]}
-ep = torch.export._trace._export(
+ep = torch.export.export(
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 # check forward pass
 out0, out1 = ep.module()(torch.randn(9), torch.randn(27))
@@ -13729,7 +13729,7 @@ def forward(self, x, y):
 Foo(),
 inputs,
 dynamic_shapes=dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 ).run_decompositions()

 self.assertEqual(
@@ -14141,11 +14141,11 @@ graph():

 inputs = (torch.randn(5), torch.randn(3))
 shapes = {"x": (Dim("dx"),), "y": (Dim("dy"),)}
-ep = torch.export._trace._export(
+ep = torch.export.export(
 Foo(),
 inputs,
 dynamic_shapes=shapes,
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 # count 2 pow nodes, 2 sym_size.int nodes
 self.assertEqual(
@@ -14944,11 +14944,11 @@ def forward(self, x):

 for private_api in (True, False):
 if private_api:
-ep = torch.export._trace._export(
+ep = torch.export.export(
 ModConstraint(),
 (torch.randn(3, 4),),
 dynamic_shapes={"x": (dynamic, dynamic)},
-allow_complex_guards_as_runtime_asserts=True,
+prefer_deferred_runtime_asserts_over_guards=True,
 )
 else:
 ep = export(
@@ -14962,7 +14962,7 @@ def forward(self, x):
 for node in ep.graph.nodes
 ].count(True)
 if private_api:
-self.assertEqual(num_asserts, 7)
+self.assertEqual(num_asserts, 6)
 with self.assertRaisesRegex(
 RuntimeError,
 r"Runtime assertion failed for expression Eq\(Mod\(s27\*s77, s77 - 1\), 0\)",
@@ -258,12 +258,6 @@ capture_dynamic_output_shape_ops = (
 # hybrid backed unbacked symints
 prefer_deferred_runtime_asserts_over_guards = False

-# For complex dynamic shapes guards that we're unable to specify with dynamo/export's
-# range constraints + dims + derived dims language, we raise constraint violation
-# errors or specialize by default. If set to True, this flag avoids crashing/specialization,
-# and allows complex guards as runtime assertions in the graph.
-allow_complex_guards_as_runtime_asserts = False
-
 # By default, dynamo will treat all ints as backed SymInts, which means (1) it
 # will wait to see the int change over multiple runs before generalizing and
 # (2) it will still always 0/1 specialize an int. When true, this knob
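After this hunk, the only shape-assertion knob left in torch/_dynamo/config.py is the one kept above. A minimal usage sketch (the compiled function here is illustrative, not part of the diff):

import torch

# Assumed usage: toggle the surviving config flag directly; the removed
# allow_complex_guards_as_runtime_asserts alias no longer exists.
torch._dynamo.config.prefer_deferred_runtime_asserts_over_guards = True

@torch.compile(dynamic=True)
def f(x, y):
    return x.reshape(-1) + y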
@@ -1734,7 +1734,6 @@ def export(
 same_signature: bool = True,
 disable_constraint_solver: bool = False,
 prefer_deferred_runtime_asserts_over_guards: bool = False,
-allow_complex_guards_as_runtime_asserts: bool = False,
 _log_export_usage: bool = True,
 constraints: Optional[list[Constraint]] = None,
 **extra_kwargs: Any,
@@ -1961,7 +1960,6 @@ def export(
 capture_dynamic_output_shape_ops=True,
 capture_scalar_outputs=True,
 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 ),
 _compiling_state_context(),
 ):
@@ -468,7 +468,6 @@ class OutputGraph(OutputGraphGuardsState):
 allow_scalar_outputs=config.capture_scalar_outputs,
 allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
 prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
-allow_complex_guards_as_runtime_asserts=config.allow_complex_guards_as_runtime_asserts,
 co_fields=self.co_fields,
 )

@@ -330,7 +330,7 @@ def make_fake_inputs(
 args,
 kwargs,
 dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=False,
+prefer_deferred_runtime_asserts_over_guards=False,
 ):
 """
 Given an nn module, example inputs, and constraints, return a new fake mode,
@@ -382,8 +382,7 @@ def make_fake_inputs(
 shape_env=ShapeEnv(
 tracked_fakes=[],
 co_fields=co_fields,
-prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 trace_asserts=True,
 ),
 allow_non_fake_inputs=True,
@@ -158,7 +158,7 @@ def export_for_training(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 )

@@ -282,7 +282,7 @@ def export(
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
 pre_dispatch=True,
-allow_complex_guards_as_runtime_asserts=prefer_deferred_runtime_asserts_over_guards,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 )
 except Exception as e:
 draft_export_msg = (
@@ -750,7 +750,7 @@ def _export_to_torch_ir(
 *,
 preserve_module_call_signature: tuple[str, ...] = (),
 disable_constraint_solver: bool = False,
-allow_complex_guards_as_runtime_asserts: bool = False,
+prefer_deferred_runtime_asserts_over_guards: bool = False,
 restore_fqn: bool = True,
 _log_export_usage: bool = True,
 same_signature: bool = True,
@@ -810,10 +810,7 @@ def _export_to_torch_ir(
 assume_static_by_default=True,
 tracing_mode="symbolic",
 disable_constraint_solver=disable_constraint_solver,
-# currently the following 2 flags are tied together for export purposes,
-# but untangle for sake of dynamo export api
-prefer_deferred_runtime_asserts_over_guards=allow_complex_guards_as_runtime_asserts,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 _log_export_usage=_log_export_usage,
 same_signature=same_signature,
 )(
@@ -1402,7 +1399,7 @@ def _strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-allow_complex_guards_as_runtime_asserts: bool,
+prefer_deferred_runtime_asserts_over_guards: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1416,7 +1413,7 @@ def _strict_export(
 dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 restore_fqn=False, # don't need to restore because we will do it later
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 _log_export_usage=False,
 )

@@ -1859,7 +1856,7 @@ def _non_strict_export(
 dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]],
 preserve_module_call_signature: tuple[str, ...],
 orig_in_spec: TreeSpec,
-allow_complex_guards_as_runtime_asserts: bool,
+prefer_deferred_runtime_asserts_over_guards: bool,
 _to_aten_func: Callable,
 ) -> ExportArtifact:
 """
@@ -1956,7 +1953,7 @@ def _non_strict_export(
 args,
 kwargs,
 dynamic_shapes,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, # for shape env initialization
 )

 fake_params_buffers = _fakify_params_buffers(fake_mode, mod)
@@ -2037,7 +2034,7 @@ def _export_for_training(
 *,
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
-allow_complex_guards_as_runtime_asserts: bool = False,
+prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
 global _EXPORT_MODULE_HIERARCHY
 _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod)
@@ -2062,7 +2059,7 @@ def _export_for_training(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=orig_in_spec,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 _to_aten_func=_export_to_aten_ir_make_fx,
 )

@@ -2124,7 +2121,7 @@ def _export(
 strict: bool = True,
 preserve_module_call_signature: tuple[str, ...] = (),
 pre_dispatch: bool = False,
-allow_complex_guards_as_runtime_asserts: bool = False,
+prefer_deferred_runtime_asserts_over_guards: bool = False,
 ) -> ExportedProgram:
 """
 Traces either an nn.Module's forward function or just a callable with PyTorch
@@ -2155,7 +2152,7 @@ def _export(
 preserve_module_call_signature: A list of submodule paths for which the original
 calling conventions are preserved as metadata.

-allow_complex_guards_as_runtime_asserts:
+prefer_deferred_runtime_asserts_over_guards:
 With the current dynamic shapes language for dims and derived dims, we can run into constraints
 that are not expressible with the language. For example, flattening a matrix and adding to a vector,
 both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible.
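To make the docstring's example concrete, here is a small sketch of exporting the x.reshape([-1]) + y case with the renamed flag (the module name, dims, and shapes are illustrative assumptions, not taken from the diff):

import torch
from torch.export import Dim, export

class Flatten(torch.nn.Module):
    def forward(self, x, y):
        # a fully dynamic reshape + add emits a guard like s0*s1 == s2,
        # which the dims/derived-dims language cannot express
        return x.reshape([-1]) + y

dx, dy, dz = Dim("dx"), Dim("dy"), Dim("dz")
ep = export(
    Flatten(),
    (torch.randn(4, 6), torch.randn(24)),
    dynamic_shapes={"x": (dx, dy), "y": (dz,)},
    # defer the inexpressible guard to a runtime assert in the graph
    # instead of erroring or specializing
    prefer_deferred_runtime_asserts_over_guards=True,
)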
@@ -2199,7 +2196,7 @@ def _export(
 dynamic_shapes,
 strict=strict,
 preserve_module_call_signature=preserve_module_call_signature,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 )
 dtrace_structured("exported_program", payload_fn=lambda: str(ep))
 return ep
@@ -2224,7 +2221,7 @@ def _export(
 dynamic_shapes=dynamic_shapes,
 preserve_module_call_signature=preserve_module_call_signature,
 orig_in_spec=original_in_spec,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
+prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
 _to_aten_func=functools.partial(
 _export_to_aten_ir,
 pre_dispatch=pre_dispatch,
@@ -3536,7 +3536,6 @@ class ShapeEnvSettings:
 specialize_zero_one: bool
 duck_shape: bool
 prefer_deferred_runtime_asserts_over_guards: bool
-allow_complex_guards_as_runtime_asserts: bool
 trace_asserts: bool

@@ -3674,10 +3673,6 @@ class ShapeEnv:
 # in guards is helpful, since these guards in some sense are overly
 # pedantic. See also https://github.com/pytorch/pytorch/issues/121749
 prefer_deferred_runtime_asserts_over_guards: bool = False,
-# When True, does not emit or raise constraint violation errors on
-# implicit guards generated by ops, and defers to runtime assertions
-# in the graph instead. For export.
-allow_complex_guards_as_runtime_asserts: bool = False,
 # XXX Add any new settings that could affect FakeTensor evaluation
 # to: torch._subclasses.fake_tensor._ShapeEnvSettings
 trace_asserts: bool = False,
@@ -3694,7 +3689,6 @@ class ShapeEnv:
 specialize_zero_one=specialize_zero_one,
 duck_shape=duck_shape,
 prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
-allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
 trace_asserts=trace_asserts,
 )

@@ -3906,10 +3900,6 @@ class ShapeEnv:
 def prefer_deferred_runtime_asserts_over_guards(self) -> bool:
 return self.settings.prefer_deferred_runtime_asserts_over_guards

-@property
-def allow_complex_guards_as_runtime_asserts(self) -> bool:
-return self.settings.allow_complex_guards_as_runtime_asserts
-
 @contextmanager
 def patch_source_specialization(
 self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr]
@@ -6659,7 +6649,7 @@ class ShapeEnv:
 assert isinstance(a, sympy.Symbol)

 if (
-self.allow_complex_guards_as_runtime_asserts
+self.prefer_deferred_runtime_asserts_over_guards
 and not _is_supported_equivalence(tgt)
 ):
 return # continuing leads to placeholder shapes having complex expressions that we can't resolve
@@ -7641,7 +7631,15 @@ class ShapeEnv:
 # is no longer necessary)
 self._maybe_guard_rel(g)

-if not self.allow_complex_guards_as_runtime_asserts:
+if (
+torch.compiler.is_exporting()
+and self.prefer_deferred_runtime_asserts_over_guards
+):
+# it's fine to defer simple guards here without checking,
+# the _maybe_guard_rel() call above will set replacements if possible,
+# and so the result here will be statically known
+self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
+else:
 # at this point, we've evaluated the concrete expr value, and have
 # flipped/negated the guard if necessary. Now we know what to guard
 # or defer to runtime assert on.
@@ -7650,11 +7648,6 @@ class ShapeEnv:
 )
 self.guards.append(guard)
 self.axioms.update(dict(self.get_implications(self.simplify(g))))
-else:
-# it's fine to defer simple guards here without checking,
-# the _maybe_guard_rel() call above will set replacements if possible,
-# and so the result here will be statically known
-self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}")
 else:
 self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec)
