mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
dynamic shapes builder API (#124898)
This PR introduces a new way of building `dynamic_shapes` for export. The idea is to build up a mapping from input tensors to the dynamic shapes that should be assigned to their corresponding fake tensors. This mapping is automatically converted to the current form of `dynamic_shapes`, which must exactly match the structure of inputs. We do this by using pytree utils. With the current `dynamic_shapes`, we had to be careful about user-defined classes that are registered with pytree, since such classes are not necessarily polymorphic containers; they may be fine containing tensors, but not dynamic shapes. Thus we had decided to allow input instances of such classes to be associated with dynamic shapes in flattened form. This decision needs to be mirrored in this PR as well. To make it easier to keep these code paths in sync, we refactor the current recursive procedure for associating inputs with dynamic shapes to use the same pytree utils. This needs minor fixes to a few tests where `dynamic_shapes` were not exactly matching the structure of inputs. Differential Revision: D56551992 Pull Request resolved: https://github.com/pytorch/pytorch/pull/124898 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
31801918e9
commit
e7846447e0
@ -61,7 +61,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
|
||||
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim, ShapesCollection
|
||||
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
|
||||
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
|
||||
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
|
||||
|
Reference in New Issue
Block a user