infer dynamic shapes through additional inputs (#150144)

Summary:
Instead of explicitly specifying dynamic shapes, it is possible to infer them from additional example inputs. Together with the example inputs provided to export, we can basically make any varying dim dynamic and keep any fixed dim static. This should be useful for prod scenarios that have access to tests and/or profiling data, yet are somewhat removed from the model authoring process.

However, this alone is not satisfactory: the exported program by design has only one graph, representing one path through the model, and we cannot necessarily guarantee that this graph works for the additional example inputs, because different guards might have been created if we had exported with them instead (corresponding to different traced paths). However, checking that the additional example inputs satisfy the guards created by the original export should be sufficient for generalization.

Now, while we don't preserve all guards in the exported program, we do check a subset of them as part of input matching. So we add a verification step at the end of export when such additional example inputs are provided. This should be enough for now.

Test Plan: added test (positive and negative cases)

Differential Revision: D72001771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150144
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Avik Chaudhuri
2025-04-01 21:13:39 +00:00
committed by PyTorch MergeBot
parent 0d44a8aea1
commit b70d105c77
5 changed files with 154 additions and 3 deletions

View File

@ -797,6 +797,12 @@ API Reference
.. automethod:: dynamic_shapes
.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs
.. automethod:: add
.. automethod:: dynamic_shapes
.. automethod:: verify
.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

View File

@ -3888,6 +3888,62 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder":
self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
def test_dynamic_shapes_inferred_basic(self):
    class M(torch.nn.Module):
        def forward(self, x, y, z):
            # x and y[0] must have same dynamic shape (say `dim`) >= 3
            tmp = (x + y[0])[:3]
            # z["k"] must have static shape = 3
            return tmp * z["k"]

    model = M()
    args = (torch.randn(4), [torch.randn(4)], {"k": torch.randn(3)})

    # 4->5, 4->5, 3->3: x and y[0] vary together, z["k"] stays fixed.
    extra = torch.export.AdditionalInputs()
    extra.add((torch.randn(5), [torch.randn(5)], {"k": torch.randn(3)}))
    ep = export(model, args, dynamic_shapes=extra)

    # The two varying inputs should share one symbolic dim; z["k"] stays static.
    placeholder_shapes = [
        str(tuple(node.meta["val"].shape))
        for node in ep.graph.find_nodes(op="placeholder")
    ]
    dim = next(iter(ep.range_constraints.keys()))
    self.assertEqual(placeholder_shapes, [f"({dim},)", f"({dim},)", "(3,)"])

    def expect_error(bad_args, run_time_msg, compile_time_msg):
        # Running the exported program on bad inputs must fail...
        with self.assertRaisesRegex(RuntimeError, run_time_msg):
            ep.module()(*bad_args)
        # ...and exporting with them as additional inputs must fail too.
        bad_extra = torch.export.AdditionalInputs()
        bad_extra.add(bad_args)
        with self.assertRaisesRegex(RuntimeError, compile_time_msg):
            export(model, args, dynamic_shapes=bad_extra)

    # 4->2, 4->2, 3->3: violates the lower bound (>= 3) on the dynamic dim.
    expect_error(
        bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}),
        run_time_msg="Expected input.*to be >= 3, but got 2",
        compile_time_msg="Expected input.*to be >= 3, but got 2",
    )
    # 4->6, 4->7, 3->3: breaks the equality between x's and y[0]'s dims.
    expect_error(
        bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}),
        run_time_msg="Expected input.*to be equal to 6, but got 7",
        compile_time_msg="Expected input.*to be equal to 6, but got 7",
    )
    # 4->5, 4->5, 3->4: changes the dim that was inferred to be static.
    expect_error(
        bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}),
        run_time_msg="Expected input.*to be equal to 3, but got 4",
        compile_time_msg=r"Constraints violated.*\n.*was inferred to be a constant \(3\)",
    )
def test_mismatched_dynamic_shapes(self):
AUTO, STATIC = Dim.AUTO, Dim.STATIC

View File

@ -51,13 +51,14 @@ __all__ = [
"unflatten",
"FlatArgsAdapter",
"UnflattenedModule",
"AdditionalInputs",
]
# To make sure export specific custom ops are loaded
import torch.export.custom_ops
from .decomp_utils import CustomDecompTable
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection
from .exported_program import (
default_decompositions,
ExportedProgram,

View File

@ -1155,10 +1155,15 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes):
kwargs = kwargs if kwargs is not None else {}
_, original_in_spec = pytree.tree_flatten((args, kwargs))
if isinstance(dynamic_shapes, torch.export.AdditionalInputs):
verify_additional_inputs = dynamic_shapes.verify
dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)
else:
verify_additional_inputs = lambda ep: None # noqa: E731
if isinstance(dynamic_shapes, torch.export.ShapesCollection):
dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)
return args, kwargs, original_in_spec, dynamic_shapes
return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs
def _get_module_call_graph(
@ -1971,6 +1976,7 @@ def _export_for_training(
kwargs,
orig_in_spec,
dynamic_shapes,
verify_additional_inputs,
) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
original_state_dict = _get_original_state_dict(mod)
@ -2033,6 +2039,7 @@ def _export_for_training(
verifiers=[TrainingIRVerifier],
)
verify_additional_inputs(exported_program)
return exported_program
@ -2132,6 +2139,7 @@ def _export(
kwargs,
original_in_spec,
dynamic_shapes,
verify_additional_inputs,
) = _process_export_inputs(mod, args, kwargs, dynamic_shapes)
original_state_dict = _get_original_state_dict(mod)
@ -2205,4 +2213,5 @@ def _export(
dtrace_structured("exported_program", payload_fn=lambda: str(exported_program))
verify_additional_inputs(exported_program)
return exported_program

View File

@ -34,6 +34,7 @@ __all__ = [
"Dim",
"dims",
"refine_dynamic_shapes_from_suggested_fixes",
"AdditionalInputs",
]
@ -713,6 +714,84 @@ class ShapesCollection:
return dynamic_shapes
class AdditionalInputs:
    """
    Infers ``dynamic_shapes`` from additional example inputs.

    Particularly useful for deployment engineers who, on the one hand, may have
    ample testing or profiling data that gives a fair sense of representative
    inputs for a model, but on the other hand, may not know the model well
    enough to guess which input shapes should be dynamic.

    A dimension whose size differs from the original inputs is considered
    dynamic; one whose size is the same is considered static. Moreover, the
    additional inputs are verified against the exported program, which
    guarantees that tracing with them instead of the original inputs would have
    generated the same graph.

    Example::

        args0, kwargs0 = ...  # example inputs for export

        # other representative inputs that the exported program will run on
        dynamic_shapes = torch.export.AdditionalInputs()
        dynamic_shapes.add(args1, kwargs1)
        ...
        dynamic_shapes.add(argsN, kwargsN)

        torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes)
    """

    def __init__(self):
        # Each entry is an (args, kwargs) pair; kwargs may be None.
        self._examples = []

    def add(self, args, kwargs=None):
        """
        Record one additional example input, given as ``args`` and ``kwargs``.
        """
        assert type(args) is tuple, f"Representative args {args} must be a tuple"
        assert (
            kwargs is None or type(kwargs) is dict
        ), f"Representative kwargs {kwargs} must be None or a dict"
        self._examples.append((args, kwargs))

    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Infers a ``dynamic_shapes`` pytree structure by merging the shapes of
        the original inputs ``args``/``kwargs`` with the shapes of every
        recorded additional input: a dim whose size agrees across all examples
        stays static, otherwise it is marked dynamic.
        """
        # Shape pytrees for the original inputs followed by each example.
        shapes_per_example = [
            _tree_map_with_path(
                lambda path, t: tuple(t.shape), _combine_args(m, ex_args, ex_kwargs)
            )
            for ex_args, ex_kwargs in [(args, kwargs), *self._examples]
        ]
        original_shapes, other_shapes = shapes_per_example[0], shapes_per_example[1:]

        def merge_dim(path, dim, *other_dims):
            # Keep the concrete size only if it is identical in every example.
            if any(other_dim != dim for other_dim in other_dims):
                return Dim.DYNAMIC
            return dim

        return tree_map_with_path(
            merge_dim,
            original_shapes,
            *other_shapes,
            is_leaf=lambda i: type(i) is int,
        )

    def verify(self, ep):
        """
        Verifies that the exported program accepts each recorded additional
        input (i.e., the additional inputs satisfy its input constraints).
        """
        unlifted = ep.module()
        for ex_args, ex_kwargs in self._examples:
            torch.export._unlift._check_input_constraints_pre_hook(
                unlifted, ex_args, ex_kwargs or {}
            )
def _warn_on_None_dynamic_shape_dimension():
msg = (
"Using None as a dynamic shape dimension is deprecated. "