dynamic shapes builder API (#124898)

This PR introduces a new way of building `dynamic_shapes` for export. The idea is to build up a mapping from input tensors to the dynamic shapes that should be assigned to their corresponding fake tensors.

This mapping is automatically converted to the current form of `dynamic_shapes`, which must exactly match the structure of inputs. We do this by using pytree utils.

With the current `dynamic_shapes`, we had to be careful about user-defined classes that are registered with pytree, since  such classes are not necessarily polymorphic containers; they may be fine containing tensors, but not dynamic shapes. Thus we had decided to allow input instances of such classes to be associated with dynamic shapes in flattened form. This decision needs to be mirrored in this PR as well. To make it easier to keep these code paths in sync, we refactor the current recursive procedure for associating inputs with dynamic shapes to use the same pytree utils. This needs minor fixes to a few tests where `dynamic_shapes` were not exactly matching the structure of inputs.

Differential Revision: D56551992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124898
Approved by: https://github.com/zhxchen17
This commit is contained in:
Avik Chaudhuri
2024-04-30 03:59:49 +00:00
committed by PyTorch MergeBot
parent 31801918e9
commit e7846447e0
6 changed files with 254 additions and 75 deletions

View File

@ -61,7 +61,7 @@ __all__ = [
]
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule