dynamic shapes builder API (#124898)

This PR introduces a new way of building `dynamic_shapes` for export. The idea is to build up a mapping from input tensors to the dynamic shapes that should be assigned to their corresponding fake tensors.

This mapping is automatically converted to the current form of `dynamic_shapes`, which must exactly match the structure of inputs. We do this by using pytree utils.
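
For illustration, here is a minimal sketch of the builder in use, distilled from the tests added below; the commented-out structured form is the rough shape that the conversion is expected to produce for these arguments, not output copied from the implementation.

```python
import torch
from torch.export import Dim, ShapesCollection, export

class M(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y[0] + z["k"]

x = torch.randn(4)
y = [torch.randn(4)]
z = {"k": torch.randn(4)}
args = (x, y, z)

# Assign a dynamic shape spec directly to each input tensor.
dim = Dim("dim", max=10)
shapes = ShapesCollection()
shapes[x] = (dim,)
shapes[y[0]] = (dim,)
shapes[z["k"]] = (dim,)

# Under the hood the collection is converted into the structured form that
# mirrors the inputs, roughly:
#   dynamic_shapes = {"x": (dim,), "y": [(dim,)], "z": {"k": (dim,)}}
ep = export(M(), args, dynamic_shapes=shapes)
```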

With the current `dynamic_shapes`, we had to be careful about user-defined classes that are registered with pytree, since such classes are not necessarily polymorphic containers; they may be fine containing tensors, but not dynamic shapes. Thus we decided to allow input instances of such classes to be associated with dynamic shapes in flattened form. This decision needs to be mirrored in this PR as well. To make it easier to keep these code paths in sync, we refactor the current recursive procedure for associating inputs with dynamic shapes to use the same pytree utils. This requires minor fixes to a few tests where `dynamic_shapes` did not exactly match the structure of inputs.
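
As a sketch of how this interacts with pytree-registered classes (adapted from the pytree test added below; the dataclass `Inp`, the module `M`, and the `serialized_type_name` value are illustrative placeholders), shapes are attached to the leaf tensors, never to the class instance, and the generated spec follows the class's flattened children:

```python
import torch
from dataclasses import dataclass
from typing import Dict, List
from torch import Tensor
from torch.export import Dim, ShapesCollection, export, register_dataclass

@dataclass
class Inp:
    x: Tensor
    y: List[Tensor]
    z: Dict[str, Tensor]

# Register the dataclass with pytree so export can flatten/unflatten it.
register_dataclass(Inp, serialized_type_name="example.Inp")

class M(torch.nn.Module):
    def forward(self, inp: Inp):
        return inp.x + inp.y[0] + inp.z["k"]

m = M()
x, y, z = torch.randn(4), [torch.randn(4)], {"k": torch.randn(4)}
args = (Inp(x, y, z),)

dim = Dim("dim", max=10)
shapes = ShapesCollection()
shapes[x] = (dim,)       # specs go on the leaf tensors,
shapes[y[0]] = (dim,)    # not on the Inp instance
shapes[z["k"]] = (dim,)

# Inp may not be a polymorphic container, so the generated spec is expressed
# over Inp's flattened children, roughly: {"inp": [(dim,), [(dim,)], {"k": (dim,)}]}
ep = export(m, args, dynamic_shapes=shapes.dynamic_shapes(m, args))
```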

Differential Revision: D56551992

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124898
Approved by: https://github.com/zhxchen17
Avik Chaudhuri
2024-04-30 03:59:49 +00:00
committed by PyTorch MergeBot
parent 31801918e9
commit e7846447e0
6 changed files with 254 additions and 75 deletions

View File

@@ -1 +1 @@
58a412cb271a3f98ae2e01fd1d24bdbb66645d4e
73b915b55d96553a0e370b2bab01f47b8c2a9e7c

View File

@@ -685,6 +685,10 @@ API Reference
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

    .. automethod:: dynamic_shapes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

View File

@@ -10,6 +10,7 @@ import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Dict, List
import torch
import torch._dynamo as torchdynamo
@@ -124,6 +125,13 @@ def foo_functional(x):
    return a.cos()


@dataclass
class Inp:
    x: Tensor
    y: List[Tensor]
    z: Dict[str, Tensor]


NON_STRICT_SUFFIX = "_non_strict"
RETRACEABILITY_SUFFIX = "_retraceability"
@@ -1090,6 +1098,86 @@ class TestExport(TestCase):
        self.assertEqual(range_lower_bounds, [1, 2])
        self.assertEqual(range_upper_bounds, [2, 3])

    def test_dynamic_shapes_builder_basic(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[0] + z["k"]

        m = M()
        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (x, y, z)

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, dynamic_shapes=shapes_collection)
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    def test_dynamic_shapes_builder_kwargs(self):
        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y[0] + z["k"]

        m = M()
        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (x,)
        kwargs = {"z": z, "y": y}

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, kwargs=kwargs, dynamic_shapes=shapes_collection)
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    # retracing doesn't seem to like dataclass registration,
    # raising a dynamo error in fx_pytree.tree_flatten_spec
    @testing.expectedFailureRetraceability
    def test_dynamic_shapes_builder_pytree(self):
        torch.export.register_dataclass(
            Inp,
            serialized_type_name="test_dynamic_shapes_builder_pytree.Inp",
        )

        class M(torch.nn.Module):
            def forward(self, inp: Inp):
                return inp.x + inp.y[0] + inp.z["k"]

        m = M()
        x = torch.randn(4)
        y = [torch.randn(4)]
        z = {"k": torch.randn(4)}
        args = (Inp(x, y, z),)

        shapes_collection = torch.export.ShapesCollection()
        dim = torch.export.Dim("dim", max=10)
        shapes_collection[x] = (dim,)
        shapes_collection[y[0]] = (dim,)
        shapes_collection[z["k"]] = (dim,)

        ep = export(m, args, dynamic_shapes=shapes_collection.dynamic_shapes(m, args))
        sym = next(iter(ep.range_constraints.keys()))
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

    def test_raise_user_error_when_guard_on_data_dependent_operation(self):
        class M(torch.nn.Module):
            def forward(self, x):
@@ -1537,7 +1625,7 @@ class TestExport(TestCase):
                return torch.matmul(inputs[0], inputs[1])

        foo = Foo()
        inputs = ((torch.randn(10, 2, 3), torch.randn(10, 3, 4)),)
        inputs = ([torch.randn(10, 2, 3), torch.randn(10, 3, 4)],)
        batch = Dim("batch")
        efoo = export(
            foo, inputs, dynamic_shapes={"inputs": [{0: batch} for _ in range(2)]}
@@ -1553,6 +1641,11 @@ class TestExport(TestCase):
        self.assertEqual(efoo.module()(*inputs).shape, foo(*inputs).shape)

        # pass dynamic shapes of inputs [dataclass]

        # TODO(avik): This part of the test should have failed both serde and retracing
        # but these failures are hidden because of the local import of `export` in this test.
        # The serde failure is benign, and easily avoided by moving the dataclass definition
        # to the top-level. OTOH the retracing failure needs further investigation.
        @dataclass
        class DataClass:
            a: Tensor

View File

@@ -61,7 +61,7 @@ __all__ = [
]
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim
from .dynamic_shapes import Constraint, Dim, dims, dynamic_dim, ShapesCollection
from .exported_program import ExportedProgram, ModuleCallEntry, ModuleCallSignature
from .graph_signature import ExportBackwardSignature, ExportGraphSignature
from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule

View File

@@ -922,8 +922,6 @@ def _export(
    Returns:
        An ExportedProgram containing the traced method.
    """
    from .dynamic_shapes import _process_dynamic_shapes

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
@@ -940,7 +938,8 @@ def _export(
        _EXPORT_FLAGS = flags

    kwargs = kwargs or {}

    _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes) # TODO(avik): remove
    if isinstance(dynamic_shapes, torch.export.ShapesCollection):
        dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs)

    constant_attrs = _gather_constant_attrs(mod)

View File

@@ -8,7 +8,7 @@ from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.utils._pytree import SUPPORTED_NODES
from torch.utils._pytree import _get_node_type, BUILTIN_TYPES, SUPPORTED_NODES, tree_map
from .exported_program import ExportedProgram
@@ -553,72 +553,151 @@ def _process_equalities(
        derived_equalities.append((source, root, fn))


def _tree_map(
    func: Callable[..., Any],
    tree: Any,
    *dynamic_shapes: Any,
) -> Any:
    """
    Customized tree_map for mapping pytrees to dynamic_shapes.

    For built-in types (e.g., standard collections) this behaves exactly like tree_map.

    OTOH for a user-defined class C registered with pytree, we cannot assume that a C
    containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not
    be a polymorphic container). In that case we use the flattened form of C instead.
    Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes).

    Args:
        func: function to apply to each (int, float, str, bool, None, torch.Tensor)
        tree: input pytree
        dynamic_shapes: zero or more (typically one) dynamic_shapes to match

    Returns:
        output pytree mapping func to each (int, float, str, bool, None, torch.Tensor)
    """

    def is_leaf(t):
        # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types
        # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types
        # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES,
        # as well as user-defined classes registered with pytree, which are.
        return _get_node_type(t) not in BUILTIN_TYPES

    def f(t, *dynamic_shapes):
        typ = _get_node_type(t)
        # typ is not in BUILTIN_TYPES
        if typ in SUPPORTED_NODES:
            # thus typ is a user-defined class registered with pytree,
            # in which case flatten and recurse
            return tree_map(
                f,
                SUPPORTED_NODES[typ].flatten_fn(t)[0],
                *dynamic_shapes,
                is_leaf=is_leaf,
            )
        else:
            return func(t, *dynamic_shapes)

    return tree_map(f, tree, *dynamic_shapes, is_leaf=is_leaf)


def _combine_args(f, args, kwargs):
    # combine args and kwargs following the signature of f, as it happens
    # in the body of f when called with *args, **kwargs
    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)
    )
    kwargs = kwargs if kwargs is not None else {}
    return signature.bind(*args, **kwargs).arguments
class ShapesCollection:
    """
    Builder for dynamic_shapes.
    Used to assign dynamic shape specifications to tensors that appear in inputs.

    Example::
        args = ({"x": tensor_x, "others": [tensor_y, tensor_z]})

        dim = torch.export.Dim(...)
        dynamic_shapes = torch.export.ShapesCollection()
        dynamic_shapes[tensor_x] = (dim, dim + 1, 8)
        dynamic_shapes[tensor_y] = {0: dim * 2}

        # This is equivalent to the following (now auto-generated):
        # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]}

        torch.export(..., args, dynamic_shapes=dynamic_shapes)
    """

    def __init__(self):
        self._shapes = {}

    def __setitem__(self, t, shape):
        assert isinstance(
            t, torch.Tensor
        ), f"Cannot assign shape to non-tensor type {type(t)}"

        # TODO(avik): check that shape is indeed a Shape
        t_id = id(t)
        if t_id in self._shapes:
            _shape = self._shapes[t_id]
            assert (
                shape == _shape
            ), f"Shapes assigned to tensor do not match: expected {_shape}, got {shape}"
        else:
            self._shapes[id(t)] = shape

    def __getitem__(self, t):
        t_id = id(t)
        if t_id in self._shapes:
            return self._shapes[t_id]
        else:
            return None

    def __len__(self):
        return len(self._shapes)

    def dynamic_shapes(self, m, args, kwargs=None):
        """
        Generate dynamic_shapes.
        """
        t_ids = set()

        def find_shape(t):
            t_id = id(t)
            if t_id in self._shapes:
                t_ids.add(t_id)
                return self._shapes[t_id]
            else:
                return None

        combined_args = _combine_args(m, args, kwargs)
        dynamic_shapes = _tree_map(find_shape, combined_args)
        if any(t_id not in t_ids for t_id in self._shapes):
            raise ValueError(
                "Some tensors that were assigned shapes were not found in args. "
                "Maybe such tensors were copied when passing them as args? "
                "Maybe such tensors are contained in classes that were not registered with pytree?"
            )
        return dynamic_shapes
def _process_dynamic_shapes(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
    from collections.abc import Mapping, Sequence

    from torch._dynamo.exc import UserError, UserErrorType

    if dynamic_shapes is None or len(dynamic_shapes) == 0:
        return None

    kwargs = kwargs if kwargs is not None else {}

    def tree_zip(combined_args, dynamic_shapes):
        if isinstance(combined_args, (tuple, list)):
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Sequence, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for i, shape in enumerate(dynamic_shapes):
                yield from tree_zip(combined_args[i], shape)
        elif isinstance(combined_args, dict):
            if not isinstance(dynamic_shapes, Mapping):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be a Mapping, "
                    f"got {dynamic_shapes} instead",
                )
            if len(combined_args) != len(dynamic_shapes):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                )
            for k, shape in dynamic_shapes.items():
                yield from tree_zip(combined_args[k], shape)
        elif type(combined_args) in SUPPORTED_NODES:
            if not isinstance(dynamic_shapes, Sequence):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a user-registered class (e.g., "
                    f"{type(combined_args)}) to be a Sequence that matches the "
                    f"flattened structure, but got {dynamic_shapes} instead",
                )
            yield from tree_zip(
                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
                dynamic_shapes,
            )
        elif isinstance(combined_args, torch.Tensor):
            yield (combined_args, dynamic_shapes)
        else:
            if dynamic_shapes is not None:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dynamic_shapes of a {type(combined_args)} to be None, "
                    f"got {dynamic_shapes} instead",
                )

    # map of Dim names representing input shape dimensions to constraints on them
    symbols: Dict[str, List[Constraint]] = defaultdict(list)
    # track roots that do not directly represent input shape dimensions
@@ -765,21 +844,25 @@ def _process_dynamic_shapes(
                    f"Unexpected dynamic_shape {shape} of Tensor, " "try None instead",
                )

    import inspect

    if isinstance(f, ExportedProgram):
        f = f.module()
    signature = (
        inspect.signature(f.forward)
        if isinstance(f, torch.nn.Module)
        else inspect.signature(f)

    def assoc_shapes(combined_args, dynamic_shapes):
        def assoc_shape(t, dynamic_shape):
            if isinstance(t, torch.Tensor):
                update_symbols(t, dynamic_shape)
            else:
                if dynamic_shape is not None:
                    raise UserError(
                        UserErrorType.INVALID_INPUT,
                        f"Cannot associate shape {dynamic_shape} to non-tensor type {type(t)}, "
                        f"expected None",
                    )
    )
    combined_args = signature.bind(*args, **kwargs).arguments
    # This means user didn't specify dynamic shapes with argument names.
    combined_args = combined_args if isinstance(dynamic_shapes, Mapping) else list(combined_args.values()) # type: ignore[assignment]
    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
        update_symbols(tensor, shape)
        _tree_map(assoc_shape, combined_args, dynamic_shapes)

    combined_args = _combine_args(f, args, kwargs)
    if not isinstance(dynamic_shapes, dict):
        assert isinstance(dynamic_shapes, (tuple, list))
        combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
    assoc_shapes(combined_args, dynamic_shapes)

    constraints = []
    for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: