dynamic shapes builder API (#124898)
This PR introduces a new way of building `dynamic_shapes` for export. The idea is to build up a mapping from input tensors to the dynamic shapes that should be assigned to their corresponding fake tensors. This mapping is automatically converted to the current form of `dynamic_shapes`, which must exactly match the structure of the inputs. We do this using pytree utils.

With the current `dynamic_shapes`, we had to be careful about user-defined classes that are registered with pytree, since such classes are not necessarily polymorphic containers; they may be fine containing tensors, but not dynamic shapes. We had therefore decided to allow input instances of such classes to be associated with dynamic shapes in flattened form. That decision is mirrored in this PR as well. To make it easier to keep these code paths in sync, we refactor the current recursive procedure for associating inputs with dynamic shapes to use the same pytree utils. This requires minor fixes to a few tests where `dynamic_shapes` did not exactly match the structure of the inputs.

Differential Revision: D56551992
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124898
Approved by: https://github.com/zhxchen17
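For illustration only, a minimal sketch of the builder API this PR introduces, mirroring the tests added below (the module `M` and the tensor names here are made up, not part of the change):

import torch
from torch.export import Dim, ShapesCollection, export

class M(torch.nn.Module):
    def forward(self, x, ys):
        return x + ys[0]

x = torch.randn(4)
ys = [torch.randn(4)]

# Assign shapes to the input tensors themselves, not to a spec that mirrors
# the input structure; the mapping is converted automatically during export.
dim = Dim("dim", max=10)
shapes = ShapesCollection()
shapes[x] = (dim,)
shapes[ys[0]] = (dim,)

ep = export(M(), (x, ys), dynamic_shapes=shapes)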
This commit is contained in:
committed by PyTorch MergeBot
parent 31801918e9
commit e7846447e0
.github/ci_commit_pins/xla.txt (vendored, 2 changed lines)
@@ -1 +1 @@
-58a412cb271a3f98ae2e01fd1d24bdbb66645d4e
+73b915b55d96553a0e370b2bab01f47b8c2a9e7c
@@ -685,6 +685,10 @@ API Reference
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

    .. automethod:: dynamic_shapes

.. autoclass:: Constraint
.. autoclass:: ExportedProgram
@@ -10,6 +10,7 @@ import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Dict, List

import torch
import torch._dynamo as torchdynamo
@@ -124,6 +125,13 @@ def foo_functional(x):
    return a.cos()


@dataclass
class Inp:
    x: Tensor
    y: List[Tensor]
    z: Dict[str, Tensor]


NON_STRICT_SUFFIX = "_non_strict"
RETRACEABILITY_SUFFIX = "_retraceability"
@@ -1090,6 +1098,86 @@ class TestExport(TestCase):
        self.assertEqual(range_lower_bounds, [1, 2])
        self.assertEqual(range_upper_bounds, [2, 3])

    def test_dynamic_shapes_builder_basic(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[0] + z["k"]

        m = M()

        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (x, y, z)

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, dynamic_shapes=shapes_collection)
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    def test_dynamic_shapes_builder_kwargs(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[0] + z["k"]

        m = M()

        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (x,)
        kwargs = {"z": z, "y": y}

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, kwargs=kwargs, dynamic_shapes=shapes_collection)
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    # retracing doesn't seem to like dataclass registration,
    # raising a dynamo error in fx_pytree.tree_flatten_spec
    @testing.expectedFailureRetraceability
    def test_dynamic_shapes_builder_pytree(self):
        torch.export.register_dataclass(
            Inp,
            serialized_type_name="test_dynamic_shapes_builder_pytree.Inp",
        )

        class M(torch.nn.Module):
            def forward(self, inp: Inp):
                return inp.x + inp.y[0] + inp.z["k"]

        m = M()
        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (Inp(x, y, z),)

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, dynamic_shapes=shapes_collection.dynamic_shapes(m, args))
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    def test_raise_user_error_when_guard_on_data_dependent_operation(self):
        class M(torch.nn.Module):
            def forward(self, x):
@@ -1537,7 +1625,7 @@ class TestExport(TestCase):
                return torch.matmul(inputs[0], inputs[1])

        foo = Foo()
-        inputs = ((torch.randn(10, 2, 3), torch.randn(10, 3, 4)),)
+        inputs = ([torch.randn(10, 2, 3), torch.randn(10, 3, 4)],)
        batch = Dim("batch")
        efoo = export(
            foo, inputs, dynamic_shapes={"inputs": [{0: batch} for _ in range(2)]}
@@ -1553,6 +1641,11 @@ class TestExport(TestCase):
        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [dataclass]

        # TODO(avik): This part of the test should have failed both serde and retracing
        # but these failures are hidden because of the local import of `export` in this test.
        # The serde failure is benign, and easily avoided by moving the dataclass definition
        # to the top-level. OTOH the retracing failure needs further investigation.
        @dataclass
        class DataClass:
            a: Tensor
@@ -61,7 +61,7 @@ __all__ = [
]


-from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
+from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule
@@ -922,8 +922,6 @@ def _export(
    Returns:
        An ExportedProgram containing the traced method.
    """
-    from .dynamic_shapes import _process_dynamic_shapes
-
    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
@@ -940,7 +938,8 @@ def _export(
    _EXPORT_FLAGS = flags

    kwargs = kwargs or {}
-    _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes)  # TODO(avik): remove
+    if isinstance(dynamic_shapes, torch.export.ShapesCollection):
+        dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)

    constant_attrs = _gather_constant_attrs(mod)
@@ -8,7 +8,7 @@ from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
-from torch.utils._pytree import SUPPORTED_NODES
+from torch.utils._pytree import _get_node_type, BUILTIN_TYPES, SUPPORTED_NODES, tree_map

from .exported_program import ExportedProgram
@@ -553,72 +553,151 @@ def _process_equalities(
        derived_equalities.append((source, root, fn))


def _tree_map(
    func: Callable[..., Any],
    tree: Any,
    *dynamic_shapes: Any,
) -> Any:
    """
    Customized tree_map for mapping pytrees to dynamic_shapes.

    For built-in types (e.g., standard collections) this behaves exactly like tree_map.

    OTOH for a user-defined class C registered with pytree, we cannot assume that a C
    containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
    be a polymorphic container). In that case we use the flattened form of C instead.
    Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).

    Args:
        func: function to apply to each (int, float, str, bool, None, torch.Tensor)
        tree: input pytree
        dynamic_shapes: zero or more (typically one) dynamic_shapes to match

    Returns:
        output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
    """

    def is_leaf(t):
        # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
        # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types
        # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
        # as well as user-defined classes registered with pytree, which are.
        return _get_node_type(t) not in BUILTIN_TYPES

    def f(t, *dynamic_shapes):
        typ = _get_node_type(t)
        # typ is not in BUILTIN_TYPES
        if typ in SUPPORTED_NODES:
            # thus typ is a user-defined class registered with pytree,
            # in which case flatten and recurse
            return tree_map(
                f,
                SUPPORTED_NODES[typ].flatten_fn(t)[0],
                *dynamic_shapes,
                is_leaf=is_leaf,
            )
        else:
            return func(t, *dynamic_shapes)

    return tree_map(f, tree, *dynamic_shapes, is_leaf=is_leaf)

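(Illustrative aside, not part of this diff: a minimal sketch, assuming a hypothetical pytree-registered dataclass `Pair` and module `N`, of the "flattened form" rule that the `_tree_map` docstring above describes for user-defined classes.)

from dataclasses import dataclass

import torch


@dataclass
class Pair:
    a: torch.Tensor
    b: torch.Tensor


torch.export.register_dataclass(Pair, serialized_type_name="aside.Pair")


class N(torch.nn.Module):
    def forward(self, p: Pair):
        return p.a + p.b


dim = torch.export.Dim("dim", max=10)
# Pair is registered with pytree but is not assumed to be a polymorphic container,
# so its dynamic shapes are passed in flattened (sequence) form: one spec per
# flattened tensor, here [spec_for_a, spec_for_b].
ep = torch.export.export(
    N(),
    (Pair(torch.randn(4), torch.randn(4)),),
    dynamic_shapes={"p": [{0: dim}, {0: dim}]},
)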
def _combine_args(f, args, kwargs):
    # combine args and kwargs following the signature of f, as it happens
    # in the body of f when called with *args, **kwargs
    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    kwargs = kwargs if kwargs is not None else {}
    return signature.bind(*args, **kwargs).arguments


class ShapesCollection:
    """
    Builder for dynamic_shapes.
    Used to assign dynamic shape specifications to tensors that appear in inputs.

    Example::
        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}
        # This is equivalent to the following (now auto-generated):
        # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export(..., args, dynamic_shapes=dynamic_shapes)
    """

    def __init__(self):
        self._shapes = {}

    def __setitem__(self, t, shape):
        assert isinstance(
            t, torch.Tensor
        ), f"Cannot assign shape to non-tensor type {type(t)}"
        # TODO(avik): check that shape is indeed a Shape
        t_id = id(t)
        if t_id in self._shapes:
            _shape = self._shapes[t_id]
            assert (
                shape == _shape
            ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}"
        else:
            self._shapes[id(t)] = shape

    def __getitem__(self, t):
        t_id = id(t)
        if t_id in self._shapes:
            return self._shapes[t_id]
        else:
            return None

    def __len__(self):
        return len(self._shapes)

    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Generate dynamic_shapes.
        """

        t_ids = set()

        def find_shape(t):
            t_id = id(t)
            if t_id in self._shapes:
                t_ids.add(t_id)
                return self._shapes[t_id]
            else:
                return None

        combined_args = _combine_args(m, args, kwargs)
        dynamic_shapes = _tree_map(find_shape, combined_args)
        if any(t_id not in t_ids for t_id in self._shapes):
            raise ValueError(
                "Some tensors that were assigned shapes were not found in args. "
                "Maybe such tensors were copied when passing them as args? "
                "Maybe such tensors are contained in classes that were not registered with pytree?"
            )
        return dynamic_shapes


def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
@@ -765,21 +844,25 @@ def _process_dynamic_shapes(
                f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead",
            )

-    import inspect
+    def assoc_shapes(combined_args, dynamic_shapes):
+        def assoc_shape(t, dynamic_shape):
+            if isinstance(t, torch.Tensor):
+                update_symbols(t, dynamic_shape)
+            else:
+                if dynamic_shape is not None:
+                    raise UserError(
+                        UserErrorType.INVALID_INPUT,
+                        f"Cannot associate shape {dynamic_shape} to non-tensor type {type(t)}, "
+                        f"expected None",
+                    )

-    if isinstance(f, ExportedProgram):
-        f = f.module()
-    signature = (
-        inspect.signature(f.forward)
-        if isinstance(f, torch.nn.Module)
-        else inspect.signature(f)
-    )
-    combined_args = signature.bind(*args, **kwargs).arguments
+        _tree_map(assoc_shape, combined_args, dynamic_shapes)

-    # This means user didn't specify dynamic shapes with argument names.
-    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values())  # type: ignore[assignment]
-    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
-        update_symbols(tensor, shape)
+    combined_args = _combine_args(f, args, kwargs)
+    if not isinstance(dynamic_shapes, dict):
+        assert isinstance(dynamic_shapes, (tuple, list))
+        combined_args = type(dynamic_shapes)(combined_args.values())  # type: ignore[assignment, misc]
+    assoc_shapes(combined_args, dynamic_shapes)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: