infer dynamic shapes through additional inputs (#150144)

Summary:
Instead of explicitly specifying dynamic shapes, it is possible to infer them from additional example inputs. Together with the example inputs provided to export, we can basically make any varying dim dynamic and keep any fixed dim static. This should be useful for prod scenarios that have access to tests and/or profiling data, yet are somewhat removed from the model authoring process.

However this alone is not satisfactory: the exported program by design has only one graph, representing one path through the model, and we cannot necessarily guarantee that this graph works for the additional example inputs because different guards might have been created if we had exported with them instead (corresponding to different traced paths). However, checking that the additional example inputs satisfy the guards created by the original export should be sufficient for generalization.

Now, while we don't preserve all guards in the exported program, we do check a subset of them as part of input matching. So we add a verification step at the end of export when such additional example inputs are provided. This should be enough for now.

Test Plan: added test (positive and negative cases)

Differential Revision: D72001771

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150144
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Avik Chaudhuri
2025-04-01 21:13:39 +00:00
committed by PyTorch MergeBot
parent 0d44a8aea1
commit b70d105c77
5 changed files with 154 additions and 3 deletions

View File

@ -51,13 +51,14 @@ __all__ = [
"unflatten",
"FlatArgsAdapter",
"UnflattenedModule",
"AdditionalInputs",
]
# To make sure export specific custom ops are loaded
import torch.export.custom_ops
from .decomp_utils import CustomDecompTable
from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection
from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection
from .exported_program import (
default_decompositions,
ExportedProgram,