[export] refactor _Dim into Dim (#149891)

Summary: forward fix T218515233

Test Plan: test_export

Differential Revision: D71769231

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149891
Approved by: https://github.com/jingsh, https://github.com/angelayi
Author:    Pian Pawakapan
Date:      2025-03-28 06:19:03 +00:00
Committer: PyTorch MergeBot
Parent:    f649ee73ce
Commit:    103bf64a3c

7 changed files with 137 additions and 128 deletions
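Note for reviewers: a minimal sketch of the user-facing behavior this refactor preserves (the dim names and ranges here are illustrative, not from the test plan):

```python
import torch

# Construction is unchanged: a Dim is a named symbolic integer with an
# inclusive range (min defaults to 0, max to unbounded).
dx = torch.export.Dim("dx", min=4, max=16)

# Derived dims still come from increasing linear ops with integer
# coefficients; they now go through Dim._derive instead of metaclass machinery.
dy = dx + 1
dz = 2 * dx

# The new Dim.__repr__ routes through Dim._readable, which is what the
# updated serde expecttest output below relies on.
print(repr(dx))  # Dim('dx', min=4, max=16)
```

The key change: `Dim` was a `type` subclass whose "instances" were themselves classes; it is now an ordinary class with `min`/`max` set in `__init__`, and `_Dim` is kept as an alias until internal callers migrate.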


@@ -790,7 +790,7 @@ API Reference
 .. autofunction:: save
 .. autofunction:: load
 .. autofunction:: register_dataclass
-.. autofunction:: torch.export.dynamic_shapes.Dim
+.. autoclass:: torch.export.dynamic_shapes.Dim
 .. autofunction:: torch.export.exported_program.default_decompositions
 .. autofunction:: dims
 .. autoclass:: torch.export.dynamic_shapes.ShapesCollection


@@ -3852,7 +3852,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         dynamic_shapes = (
             {"k": {"k": dim}},
-        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*.
+        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*Dim.*.
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
             re.escape(
@@ -12362,7 +12362,7 @@ def forward(self, x):
         self.assertExpectedInline(
             _load_dynamic_shapes(spec, from_dict=False),
-            """[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
+            """[[Dim('dx', min=4, max=16)]]""",
         )
         # check incorrect info in dims


@@ -450,7 +450,7 @@ class DynamoExporterTest(common_utils.TestCase):
         )
         dynamic_shapes = (
-            {0: torch.export.Dim("dim_x", min=3)},  # _Dim
+            {0: torch.export.Dim("dim_x", min=3)},  # Dim
             [("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)],  # custom name
             {
                 "a": {0: torch.export.Dim.AUTO},


@@ -6,7 +6,6 @@ from torch._dynamo.exc import UserError, UserErrorType
 from torch.export.dynamic_shapes import (
     _check_dynamic_shapes,
     _DerivedDim,
-    _Dim,
     _DimHint,
     _tree_map_with_path,
     Dim,
@@ -19,7 +18,7 @@ from .serialize import _dataclass_to_dict
 @dataclasses.dataclass
 class RootDim:
     """
-    This represents a _Dim object.
+    This represents a Dim object.
     """

     min: int
@@ -150,7 +149,7 @@ def _dump_dynamic_shapes(
         return out

     def _track_dim_from_dims(
-        val: Union[None, int, _DimHint, _Dim]
+        val: Union[None, int, _DimHint, Dim]
     ) -> Union[None, int, str]:
         """
         Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
@@ -160,7 +159,7 @@ def _dump_dynamic_shapes(
         if isinstance(val, _DimHint):  # store enum as string
             return val.__class__.__name__ + "." + val.type.name

-        assert isinstance(val, _Dim)
+        assert isinstance(val, Dim)

         # track root dim
         root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
@@ -297,7 +296,7 @@ def _load_dynamic_shapes(
     def deserialize_shape(
         val: Union[None, int, str]
-    ) -> Union[None, int, _Dim, _DimHint]:
+    ) -> Union[None, int, Dim, _DimHint]:
         if val is None or isinstance(val, int):
             return val
         elif val == "_DimHint.AUTO":

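As context for `_track_dim_from_dims` and `deserialize_shape` above: dim hints are stored as plain strings built from the class name and enum member name. A small sketch (the printed value assumes `_DimHintType` member names match the strings matched above):

```python
from torch.export import Dim

# Dim.AUTO is a _DimHint instance; serialization stores it as
# "<class name>.<enum member name>", which deserialize_shape matches
# literally against strings like "_DimHint.AUTO".
hint = Dim.AUTO
print(hint.__class__.__name__ + "." + hint.type.name)  # _DimHint.AUTO
```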

@@ -70,19 +70,94 @@ class _DimHint:
         return _DimHint(_DimHintType.STATIC)


-class _Dim(type):
+class Dim:
     """
-    Metaclass for :func:`Dim` types.
+    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
+    It can be used to describe multiple possible values of a dynamic tensor dimension.
+    Note that different dynamic dimensions of the same tensor, or of different tensors,
+    can be described by the same type.
+
+    Args:
+        name (str): Human-readable name for debugging.
+        min (Optional[int]): Minimum possible value of given symbol (inclusive)
+        max (Optional[int]): Maximum possible value of given symbol (inclusive)
+
+    Returns:
+        A type that can be used in dynamic shape specifications for tensors.
     """

+    AUTO = _DimHint.AUTO()
+    DYNAMIC = _DimHint.DYNAMIC()
+    STATIC = _DimHint.STATIC()
+
+    def __init__(
+        self, name: str, *, min: Optional[int] = None, max: Optional[int] = None
+    ):
+        from torch.utils._sympy.numbers import int_oo
+
+        _min = 0 if min is None else min
+        _max = int_oo if max is None else max
+        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
+        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
+        self.__name__ = name
+        self.min = _min
+        self.max = _max
+
+    def __add__(self, other) -> "Dim":
+        # e.g., dim + 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to add {other} to {self.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x + other)
+
+    def __radd__(self, other) -> "Dim":
+        return self + other
+
+    def __sub__(self, other) -> "Dim":
+        # e.g., dim - 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x - other)
+
+    def __rsub__(self, other) -> "Dim":
+        raise NotImplementedError(
+            f"Attempted to negate {self.__name__}. "
+            "(Only increasing linear operations with integer coefficients are supported.)"
+        )
+
+    def __mul__(self, other) -> "Dim":
+        # e.g., dim * 2
+        if type(other) is not int or other <= 0:
+            raise NotImplementedError(
+                f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x * other)
+
+    def __rmul__(self, other) -> "Dim":
+        return self * other
+
+    def _derived_name(self, fn) -> str:
+        from sympy import sympify
+
+        return str(fn(sympify(self.__name__)))
+
+    def _derive(self, fn) -> "Dim":
+        return _DerivedDim(self._derived_name(fn), self, fn)
+
     @staticmethod
-    def readable(name, min_, max_):
+    def _readable(name: str, min_: int, max_: int) -> str:
         from torch.utils._sympy.numbers import int_oo

         if min_ == 2:
-            min_ = None
+            min_ = None  # type: ignore[assignment]
         if max_ == int_oo:
-            max_ = None
+            max_ = None  # type: ignore[assignment]
         if min_ is None and max_ is None:
             return f"Dim('{name}')"
         if min_ is None:
@@ -91,62 +166,25 @@ class _Dim(type):
             return f"Dim('{name}', min={min_})"
         return f"Dim('{name}', min={min_}, max={max_})"

-    def __add__(cls, other):
-        # e.g., dim + 1
-        if type(other) is not int:
-            raise NotImplementedError(
-                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x + other)
-
-    def __radd__(cls, other):
-        return cls + other
-
-    def __sub__(cls, other):
-        # e.g., dim - 1
-        if type(other) is not int:
-            raise NotImplementedError(
-                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x - other)
-
-    def __rsub__(cls, other):
-        raise NotImplementedError(
-            f"Attempted to negate {cls.__name__}. "
-            "(Only increasing linear operations with integer coefficients are supported.)"
-        )
-
-    def __mul__(cls, other):
-        # e.g., dim * 2
-        if type(other) is not int or other <= 0:
-            raise NotImplementedError(
-                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x * other)
-
-    def __rmul__(cls, other):
-        return cls * other
-
-    def _derived_name(cls, fn):
-        from sympy import sympify
-
-        return str(fn(sympify(cls.__name__)))
-
-    def _derive(cls, fn):
-        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})
+    def __repr__(self):
+        return Dim._readable(self.__name__, self.min, self.max)


-class _StaticDim(_Dim):
+_Dim = Dim  # TODO(pianpwk): remove after it's no longer internally breaking
+
+
+class _StaticDim(Dim):
     """
-    Meta class for static :func:`Dim` types.
+    Class for static :func:`Dim` types.

     This class is only for setting and checking static dim constraints,
     and the user should never interact with it.
     """

+    def __init__(self, value: int):
+        self.__name__ = str(value)
+        self.value = value
+
     @property
     def min(self):
         return self.value  # type: ignore[attr-defined]
@@ -156,9 +194,9 @@ class _StaticDim(_Dim):
         return self.value  # type: ignore[attr-defined]


-class _DerivedDim(_Dim):
+class _DerivedDim(Dim):
     """
-    Metaclass for derived :func:`Dim` types.
+    Class for derived :func:`Dim` types.

     Currently we only support increasing linear expressions with integer coefficients.
     In other words, a derived Dim can always be written in the form Ax + B, where
@@ -172,6 +210,11 @@ class _DerivedDim(_Dim):
     The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
     """

+    def __init__(self, name: str, root: Dim, fn: Callable):
+        self.__name__ = name
+        self.root = root
+        self.fn = fn
+
     @property
     def min(self):
         # assume that self.fn is an increasing function
@@ -218,50 +261,17 @@ class _DerivedDim(_Dim):
         # As a consequence, roots are always regular Dims (i.e., not derived Dims).
         return _DerivedDim(
             self._derived_name(fn),
-            (int,),
-            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
+            self.root,
+            lambda x: fn(self.fn(x)),
         )


-class Dim(type):
-    """
-    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
-    It can be used to describe multiple possible values of a dynamic tensor dimension.
-    Note that different dynamic dimensions of the same tensor, or of different tensors,
-    can be described by the same type.
-
-    Args:
-        name (str): Human-readable name for debugging.
-        min (Optional[int]): Minimum possible value of given symbol (inclusive)
-        max (Optional[int]): Maximum possible value of given symbol (inclusive)
-
-    Returns:
-        A type that can be used in dynamic shape specifications for tensors.
-    """
-
-    AUTO = _DimHint.AUTO()
-    DYNAMIC = _DimHint.DYNAMIC()
-    STATIC = _DimHint.STATIC()
-
-    def __new__(
-        metacls, name: str, *, min: Optional[int] = None, max: Optional[int] = None
-    ):
-        from torch.utils._sympy.numbers import int_oo
-
-        _min = 0 if min is None else min
-        _max = int_oo if max is None else max
-        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
-        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
-        dim = _Dim(name, (int,), {"min": _min, "max": _max})
-        dim.__module__ = getattr(
-            inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
-        )
-        return dim
-
-    def __repr__(self):
-        return self.__name__
-
-
 def dims(
     *names: str, min: Optional[int] = None, max: Optional[int] = None
-) -> tuple[_Dim, ...]:
+) -> tuple[Dim, ...]:
     """
     Util to create multiple :func:`Dim` types.
@@ -722,8 +732,8 @@ def _check_dynamic_shapes(
         if dim.__name__ in bounds:
             min_, max_ = bounds[dim.__name__]
             if dim.min != min_ or dim.max != max_:
-                this_ = _Dim.readable(dim.__name__, min_, max_)
-                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
+                this_ = Dim._readable(dim.__name__, min_, max_)
+                that_ = Dim._readable(dim.__name__, dim.min, dim.max)
                 raise UserError(
                     UserErrorType.INVALID_INPUT,
                     f"Found different definitions {this_} and {that_} "
@@ -735,7 +745,7 @@ def _check_dynamic_shapes(
     def check_symbols(path, tensor, shape):
         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -750,7 +760,7 @@ def _check_dynamic_shapes(
                 )
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -911,7 +921,7 @@ def _process_dynamic_shapes(
                 ),
             )
         else:
-            assert isinstance(dim, _Dim)
+            assert isinstance(dim, Dim)
             constraint = _Constraint(  # type: ignore[assignment]
                 id(tensor),
                 i,
@@ -924,7 +934,7 @@ def _process_dynamic_shapes(
     def update_symbols(path, tensor, shape):
         def _create_static_dim(tensor, i, value):
-            return _StaticDim(str(value), (int,), {"value": value})
+            return _StaticDim(value)

         # clean out decorators from user side, or previous export call
         # we also delete these attributes in non_strict_utils.py/make_constraints()
@@ -936,7 +946,7 @@ def _process_dynamic_shapes(
         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
@@ -953,7 +963,7 @@ def _process_dynamic_shapes(
                     torch._dynamo.mark_static(tensor, i)
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
@@ -1002,14 +1012,14 @@ def _get_dim_name_mapping(
     name_to_dim = {}
     for dim in tree_flatten(
         dynamic_shapes,
-        is_leaf=lambda x: isinstance(x, _Dim),
+        is_leaf=lambda x: isinstance(x, Dim),
     )[0]:
         if dim is None:
             # NOTE: this must denote a non-Tensor or automatic at this point.
             continue
         if isinstance(dim, int):
             continue
-        elif isinstance(dim, _Dim):
+        elif isinstance(dim, Dim):
             name_to_dim[dim.__name__] = dim
             if isinstance(dim, _DerivedDim):
                 name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
@@ -1092,7 +1102,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
     # track derived dim roots
     roots: set[str] = set()
     for k, c in shape_fixes.items():
-        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
+        assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr))
         if isinstance(c, sympy.Expr):  # check dim/derived dim expression
             assert _is_supported_equivalence(c)
             shape_fixes[k] = c

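Taken together with the `_DerivedDim(name, root, fn)` constructor above, derived dims are now plain `Dim` subclass instances, so the `isinstance(dim, Dim)` checks in `_check_dynamic_shapes` and `_process_dynamic_shapes` cover them uniformly. A short sketch, assuming the module layout shown in this diff:

```python
from torch.export import Dim
from torch.export.dynamic_shapes import _DerivedDim

dx = Dim("dx", min=4, max=16)
dy = 2 * dx - 1  # composes: root stays dx, fn becomes x -> 2*x - 1

print(dy.__name__)                  # 2*dx - 1, rendered by sympy in _derived_name
print(isinstance(dy, Dim))          # True: passes the refactored checks
print(isinstance(dy, _DerivedDim))  # True: still distinguishable where needed
```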

@@ -2784,8 +2784,8 @@ class DimConstraints:
     ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]:
         return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)

-    def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes._Dim]:
-        return isinstance(dim, torch.export.dynamic_shapes._Dim) and not isinstance(
+    def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]:
+        return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance(
             dim, torch.export.dynamic_shapes._DerivedDim
         )

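The predicate keeps its old shape across the rename: `_DerivedDim` remains a subclass (now of the public `Dim`), so identifying a root dim still means matching `Dim` and then excluding `_DerivedDim`. A standalone sketch of the same narrowing (`_is_root_dim` is a hypothetical free-function version of `DimConstraints._is_dim`; `TypeGuard` needs Python 3.10+, else `typing_extensions`):

```python
from typing import TypeGuard

import torch

def _is_root_dim(dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]:
    # Match any Dim, then subtract derived dims, which now subclass Dim.
    return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance(
        dim, torch.export.dynamic_shapes._DerivedDim
    )
```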

@@ -8,7 +8,7 @@ import warnings
 from typing import Any, TYPE_CHECKING

 import torch
-from torch.export.dynamic_shapes import _Dim, _DimHint
+from torch.export.dynamic_shapes import _DimHint, Dim
 from torch.onnx._internal._lazy_import import onnxscript_ir as ir
 from torch.utils import _pytree
@@ -157,7 +157,7 @@ def from_dynamic_shapes_to_dynamic_axes(
 def _any_str_or_dim_in_dynamic_shapes(
     dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any],
 ) -> bool:
-    """Check if there is any string or _Dim in the dynamic_shapes."""
+    """Check if there is any string or Dim in the dynamic_shapes."""
     flat_dynamic_shapes, _ = _flatten_dynamic_shapes_to_axes(dynamic_shapes)
     # This indicates the dynamic_shapes includes something we don't support in axes, and it's flattened
     # to itself. Otherwise, flat_dynamic_shapes should be a list of dict/list/tuple (or None).
@@ -166,15 +166,15 @@ def _any_str_or_dim_in_dynamic_shapes(
         for axes in flat_dynamic_shapes
     ):
         return False
-    # both str and _Dim can provide custom names
+    # both str and Dim can provide custom names
    for axes in flat_dynamic_shapes:
        if isinstance(axes, dict):
            for dim in axes.values():
-                if isinstance(dim, (str, _Dim)):
+                if isinstance(dim, (str, Dim)):
                    return True
        elif isinstance(axes, (list, tuple)):
            for dim in axes:
-                if isinstance(dim, (str, _Dim)):
+                if isinstance(dim, (str, Dim)):
                    return True
    return False
@@ -190,7 +190,7 @@ def convert_str_to_export_dim(
     # for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
     # to {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
     dynamic_shapes_with_export_dim: list[
-        list[_Dim | _DimHint | None] | dict[int, _Dim | _DimHint | None] | None
+        list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None
     ] = []
     flat_dynamic_shapes, tree_structure = _flatten_dynamic_shapes_to_axes(
         dynamic_shapes
@@ -199,7 +199,7 @@ def convert_str_to_export_dim(
         if axes is None:
             dynamic_shapes_with_export_dim.append(None)
         elif isinstance(axes, dict):
-            converted_axes_dict: dict[int, _Dim | _DimHint | None] = {}
+            converted_axes_dict: dict[int, Dim | _DimHint | None] = {}
             for axis, dim in axes.items():
                 if isinstance(dim, str):
                     converted_axes_dict[axis] = torch.export.Dim.AUTO
@@ -207,7 +207,7 @@ def convert_str_to_export_dim(
                     converted_axes_dict[axis] = dim
             dynamic_shapes_with_export_dim.append(converted_axes_dict)
         elif isinstance(axes, (list, tuple)):
-            converted_axes_list: list[_Dim | _DimHint | None] = []
+            converted_axes_list: list[Dim | _DimHint | None] = []
             for dim in axes:
                 if isinstance(dim, str):
                     converted_axes_list.append(torch.export.Dim.AUTO)
@@ -292,9 +292,9 @@ def create_rename_mapping(
     return rename_mapping


-def _get_custom_axis_name(axis: _Dim | str) -> str:
+def _get_custom_axis_name(axis: Dim | str) -> str:
     """Get the custom axis name from a torch.export.Dim."""
-    if isinstance(axis, _Dim):
+    if isinstance(axis, Dim):
         return axis.__name__
     return axis
@@ -310,18 +310,18 @@ def _unflatten_dynamic_shapes_with_inputs_tree(
 def _flatten_dynamic_shapes_to_axes(
     dynamic_shapes: dict[str, Any | None] | tuple[Any, ...] | list[Any],
 ) -> tuple[list[Any], _pytree.TreeSpec]:
-    # If it's a dict/list/tuple with torch.export._Dim, we consider it's an axis to dim mapping
+    # If it's a dict/list/tuple with torch.export.Dim, we consider it's an axis to dim mapping
     def is_axes(x) -> bool:
         return (
             isinstance(x, dict)
             and all(
                 isinstance(k, int)
-                and (v is None or isinstance(v, (_Dim, _DimHint, str, int)))
+                and (v is None or isinstance(v, (Dim, _DimHint, str, int)))
                 for k, v in x.items()
             )
         ) or (
             isinstance(x, (list, tuple))
-            and all(v is None or isinstance(v, (_Dim, _DimHint, str, int)) for v in x)
+            and all(v is None or isinstance(v, (Dim, _DimHint, str, int)) for v in x)
         )

     return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes)