[export] refactor _Dim into Dim (#149891)
Summary: forward fix T218515233

Test Plan: test_export

Differential Revision: D71769231

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149891
Approved by: https://github.com/jingsh, https://github.com/angelayi
Committed by: PyTorch MergeBot
Parent: f649ee73ce
Commit: 103bf64a3c
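The user-visible effect of this refactor is that `torch.export.Dim` is now a regular class rather than a factory built on the `_Dim` metaclass, so dynamic-shape specs hold plain `Dim` instances and isinstance checks use the public class. A minimal sketch of the intended usage, assuming only the public `torch.export` API touched in the hunks below (the module and tensor sizes are illustrative):

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2


# A named symbolic dimension with an inclusive range.
batch = Dim("batch", min=2, max=64)

# After this refactor, `batch` is an instance of the Dim class
# (previously a type produced by the _Dim metaclass).
assert isinstance(batch, Dim)

# Dimensions are used in dynamic_shapes specs exactly as before.
ep = export(M(), (torch.randn(4, 8),), dynamic_shapes=({0: batch},))
```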
@@ -790,7 +790,7 @@ API Reference
 .. autofunction:: save
 .. autofunction:: load
 .. autofunction:: register_dataclass
-.. autofunction:: torch.export.dynamic_shapes.Dim
+.. autoclass:: torch.export.dynamic_shapes.Dim
 .. autofunction:: torch.export.exported_program.default_decompositions
 .. autofunction:: dims
 .. autoclass:: torch.export.dynamic_shapes.ShapesCollection
@@ -3852,7 +3852,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):

         dynamic_shapes = (
             {"k": {"k": dim}},
-        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*.
+        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*Dim.*.
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
             re.escape(
@@ -12362,7 +12362,7 @@ def forward(self, x):

         self.assertExpectedInline(
             _load_dynamic_shapes(spec, from_dict=False),
-            """[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
+            """[[Dim('dx', min=4, max=16)]]""",
         )

         # check incorrect info in dims
@@ -450,7 +450,7 @@ class DynamoExporterTest(common_utils.TestCase):
        )

        dynamic_shapes = (
-           {0: torch.export.Dim("dim_x", min=3)},  # _Dim
+           {0: torch.export.Dim("dim_x", min=3)},  # Dim
            [("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)],  # custom name
            {
                "a": {0: torch.export.Dim.AUTO},
@@ -6,7 +6,6 @@ from torch._dynamo.exc import UserError, UserErrorType
 from torch.export.dynamic_shapes import (
     _check_dynamic_shapes,
     _DerivedDim,
-    _Dim,
     _DimHint,
     _tree_map_with_path,
     Dim,
@@ -19,7 +18,7 @@ from .serialize import _dataclass_to_dict
 @dataclasses.dataclass
 class RootDim:
     """
-    This represents a _Dim object.
+    This represents a Dim object.
     """

     min: int
@@ -150,7 +149,7 @@ def _dump_dynamic_shapes(
         return out

     def _track_dim_from_dims(
-        val: Union[None, int, _DimHint, _Dim]
+        val: Union[None, int, _DimHint, Dim]
     ) -> Union[None, int, str]:
         """
         Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
@@ -160,7 +159,7 @@ def _dump_dynamic_shapes(
         if isinstance(val, _DimHint):  # store enum as string
             return val.__class__.__name__ + "." + val.type.name

-        assert isinstance(val, _Dim)
+        assert isinstance(val, Dim)

         # track root dim
         root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
@@ -297,7 +296,7 @@ def _load_dynamic_shapes(

     def deserialize_shape(
         val: Union[None, int, str]
-    ) -> Union[None, int, _Dim, _DimHint]:
+    ) -> Union[None, int, Dim, _DimHint]:
         if val is None or isinstance(val, int):
             return val
         elif val == "_DimHint.AUTO":
@@ -70,19 +70,94 @@ class _DimHint:
         return _DimHint(_DimHintType.STATIC)


-class _Dim(type):
+class Dim:
     """
-    Metaclass for :func:`Dim` types.
+    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
+    It can be used to describe multiple possible values of a dynamic tensor dimension.
+    Note that different dynamic dimensions of the same tensor, or of different tensors,
+    can be described by the same type.
+
+    Args:
+        name (str): Human-readable name for debugging.
+        min (Optional[int]): Minimum possible value of given symbol (inclusive)
+        max (Optional[int]): Maximum possible value of given symbol (inclusive)
+
+    Returns:
+        A type that can be used in dynamic shape specifications for tensors.
     """

+    AUTO = _DimHint.AUTO()
+    DYNAMIC = _DimHint.DYNAMIC()
+    STATIC = _DimHint.STATIC()
+
+    def __init__(
+        self, name: str, *, min: Optional[int] = None, max: Optional[int] = None
+    ):
+        from torch.utils._sympy.numbers import int_oo
+
+        _min = 0 if min is None else min
+        _max = int_oo if max is None else max
+        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
+        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
+        self.__name__ = name
+        self.min = _min
+        self.max = _max
+
+    def __add__(self, other) -> "Dim":
+        # e.g., dim + 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to add {other} to {self.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x + other)
+
+    def __radd__(self, other) -> "Dim":
+        return self + other
+
+    def __sub__(self, other) -> "Dim":
+        # e.g., dim - 1
+        if type(other) is not int:
+            raise NotImplementedError(
+                f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x - other)
+
+    def __rsub__(self, other) -> "Dim":
+        raise NotImplementedError(
+            f"Attempted to negate {self.__name__}. "
+            "(Only increasing linear operations with integer coefficients are supported.)"
+        )
+
+    def __mul__(self, other) -> "Dim":
+        # e.g., dim * 2
+        if type(other) is not int or other <= 0:
+            raise NotImplementedError(
+                f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. "
+                "(Only increasing linear operations with integer coefficients are supported.)"
+            )
+        return self._derive(lambda x: x * other)
+
+    def __rmul__(self, other) -> "Dim":
+        return self * other
+
+    def _derived_name(self, fn) -> str:
+        from sympy import sympify
+
+        return str(fn(sympify(self.__name__)))
+
+    def _derive(self, fn) -> "Dim":
+        return _DerivedDim(self._derived_name(fn), self, fn)
+
     @staticmethod
-    def readable(name, min_, max_):
+    def _readable(name: str, min_: int, max_: int) -> str:
         from torch.utils._sympy.numbers import int_oo

         if min_ == 2:
-            min_ = None
+            min_ = None  # type: ignore[assignment]
         if max_ == int_oo:
-            max_ = None
+            max_ = None  # type: ignore[assignment]
         if min_ is None and max_ is None:
             return f"Dim('{name}')"
         if min_ is None:
@@ -91,62 +166,25 @@ class _Dim(type):
             return f"Dim('{name}', min={min_})"
         return f"Dim('{name}', min={min_}, max={max_})"

-    def __add__(cls, other):
-        # e.g., dim + 1
-        if type(other) is not int:
-            raise NotImplementedError(
-                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x + other)
-
-    def __radd__(cls, other):
-        return cls + other
-
-    def __sub__(cls, other):
-        # e.g., dim - 1
-        if type(other) is not int:
-            raise NotImplementedError(
-                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x - other)
-
-    def __rsub__(cls, other):
-        raise NotImplementedError(
-            f"Attempted to negate {cls.__name__}. "
-            "(Only increasing linear operations with integer coefficients are supported.)"
-        )
-
-    def __mul__(cls, other):
-        # e.g., dim * 2
-        if type(other) is not int or other <= 0:
-            raise NotImplementedError(
-                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
-                "(Only increasing linear operations with integer coefficients are supported.)"
-            )
-        return cls._derive(lambda x: x * other)
-
-    def __rmul__(cls, other):
-        return cls * other
-
-    def _derived_name(cls, fn):
-        from sympy import sympify
-
-        return str(fn(sympify(cls.__name__)))
-
-    def _derive(cls, fn):
-        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})
+    def __repr__(self):
+        return Dim._readable(self.__name__, self.min, self.max)


-class _StaticDim(_Dim):
+_Dim = Dim  # TODO(pianpwk): remove after it's no longer internally breaking
+
+
+class _StaticDim(Dim):
     """
-    Meta class for static :func:`Dim` types.
+    Class for static :func:`Dim` types.

     This class is only for setting and checking static dim constraints,
     and the user should never interact with it.
     """

+    def __init__(self, value: int):
+        self.__name__ = str(value)
+        self.value = value
+
     @property
     def min(self):
         return self.value  # type: ignore[attr-defined]
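A short illustration of the instance-based construction and the new `__repr__` added above; the dimension names and bounds below are illustrative:

```python
from torch.export import Dim

# With no bounds, min defaults to 0 and max to unbounded (int_oo).
batch = Dim("batch")
print(repr(batch))  # Dim('batch', min=0)

# Explicit inclusive bounds; __init__ asserts max > min and that the
# name is a valid Python identifier.
seq = Dim("seq", min=4, max=512)

# __repr__ routes through Dim._readable, so a Dim prints as a readable
# constructor-style expression instead of a generated class name.
print(repr(seq))  # Dim('seq', min=4, max=512)
```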
@@ -156,9 +194,9 @@ class _StaticDim(_Dim):
         return self.value  # type: ignore[attr-defined]


-class _DerivedDim(_Dim):
+class _DerivedDim(Dim):
     """
-    Metaclass for derived :func:`Dim` types.
+    Class for derived :func:`Dim` types.

     Currently we only support increasing linear expressions with integer coefficients.
     In other words, a derived Dim can always be written in the form Ax + B, where
@@ -172,6 +210,11 @@ class _DerivedDim(_Dim):
     The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
     """

+    def __init__(self, name: str, root: Dim, fn: Callable):
+        self.__name__ = name
+        self.root = root
+        self.fn = fn
+
     @property
     def min(self):
         # assume that self.fn is an increasing function
@@ -218,50 +261,17 @@ class _DerivedDim(_Dim):
         # As a consequence, roots are always regular Dims (i.e., not derived Dims).
         return _DerivedDim(
             self._derived_name(fn),
-            (int,),
-            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
+            self.root,
+            lambda x: fn(self.fn(x)),
         )

+    def __repr__(self):
+        return self.__name__

-
-class Dim(type):
-    """
-    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
-    It can be used to describe multiple possible values of a dynamic tensor dimension.
-    Note that different dynamic dimensions of the same tensor, or of different tensors,
-    can be described by the same type.
-
-    Args:
-        name (str): Human-readable name for debugging.
-        min (Optional[int]): Minimum possible value of given symbol (inclusive)
-        max (Optional[int]): Maximum possible value of given symbol (inclusive)
-
-    Returns:
-        A type that can be used in dynamic shape specifications for tensors.
-    """
-
-    AUTO = _DimHint.AUTO()
-    DYNAMIC = _DimHint.DYNAMIC()
-    STATIC = _DimHint.STATIC()
-
-    def __new__(
-        metacls, name: str, *, min: Optional[int] = None, max: Optional[int] = None
-    ):
-        from torch.utils._sympy.numbers import int_oo
-
-        _min = 0 if min is None else min
-        _max = int_oo if max is None else max
-        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
-        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
-        dim = _Dim(name, (int,), {"min": _min, "max": _max})
-        dim.__module__ = getattr(
-            inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
-        )
-        return dim


 def dims(
     *names: str, min: Optional[int] = None, max: Optional[int] = None
-) -> tuple[_Dim, ...]:
+) -> tuple[Dim, ...]:
     """
     Util to create multiple :func:`Dim` types.
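The hunk above also switches `_DerivedDim` construction from metaclass arguments to a plain `(name, root, fn)` signature. A small sketch of how derived dims behave under the arithmetic defined earlier in this diff; the printed ranges assume the increasing linear mapping shown in the code:

```python
from torch.export import Dim
from torch.export.dynamic_shapes import _DerivedDim

dx = Dim("dx", min=4, max=16)

# Arithmetic on a Dim produces a _DerivedDim whose root is dx and whose
# fn records the (increasing, integer-coefficient) linear expression.
dy = dx + 1   # derived name "dx + 1"
dz = 2 * dx   # derived name "2*dx"
assert isinstance(dy, _DerivedDim) and dy.root is dx

# The derived range is fn mapped over the root's range.
print(dy.min, dy.max)  # 5 17
print(dz.min, dz.max)  # 8 32
```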
@@ -722,8 +732,8 @@ def _check_dynamic_shapes(
         if dim.__name__ in bounds:
             min_, max_ = bounds[dim.__name__]
             if dim.min != min_ or dim.max != max_:
-                this_ = _Dim.readable(dim.__name__, min_, max_)
-                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
+                this_ = Dim._readable(dim.__name__, min_, max_)
+                that_ = Dim._readable(dim.__name__, dim.min, dim.max)
                 raise UserError(
                     UserErrorType.INVALID_INPUT,
                     f"Found different definitions {this_} and {that_} "
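The bounds check above rejects reusing one dimension name with two different ranges. A hedged sketch of the failure mode it guards against, assuming the public export entry point (the model and tensor sizes are illustrative):

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


# Two dims share the name "d" but disagree on bounds; _check_dynamic_shapes
# records (min, max) per name and raises a UserError that embeds the
# readable forms produced by Dim._readable.
d1 = Dim("d", min=2, max=8)
d2 = Dim("d", min=2, max=16)

try:
    export(M(), (torch.randn(4), torch.randn(4)), dynamic_shapes=({0: d1}, {0: d2}))
except torch._dynamo.exc.UserError as e:
    print(e)  # mentions two different definitions of Dim('d', ...)
```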
@@ -735,7 +745,7 @@ def _check_dynamic_shapes(
     def check_symbols(path, tensor, shape):
         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -750,7 +760,7 @@ def _check_dynamic_shapes(
                 )
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -911,7 +921,7 @@ def _process_dynamic_shapes(
                 ),
             )
         else:
-            assert isinstance(dim, _Dim)
+            assert isinstance(dim, Dim)
             constraint = _Constraint(  # type: ignore[assignment]
                 id(tensor),
                 i,
@@ -924,7 +934,7 @@ def _process_dynamic_shapes(

     def update_symbols(path, tensor, shape):
         def _create_static_dim(tensor, i, value):
-            return _StaticDim(str(value), (int,), {"value": value})
+            return _StaticDim(value)

         # clean out decorators from user side, or previous export call
         # we also delete these attributes in non_strict_utils.py/make_constraints()
@@ -936,7 +946,7 @@ def _process_dynamic_shapes(

         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
@@ -953,7 +963,7 @@ def _process_dynamic_shapes(
                     torch._dynamo.mark_static(tensor, i)
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
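Per the two hunks above, integer entries in a dynamic_shapes spec are now wrapped as `_StaticDim(value)` (a plain constructor instead of metaclass arguments) and the corresponding axis is marked static. A minimal sketch of the dict and list spec forms this code iterates over, assuming the public `torch.export` API (the model and sizes are illustrative):

```python
import torch
from torch.export import Dim, export


class M(torch.nn.Module):
    def forward(self, x):
        return x.sum(dim=1)


batch = Dim("batch", min=2, max=128)

# Dict form: axis index -> Dim | int | None. The integer pins axis 1 to a
# static size of 8; internally it becomes _StaticDim(8).
export(M(), (torch.randn(4, 8),), dynamic_shapes=({0: batch, 1: 8},))

# List/tuple form: one entry per axis, in order.
export(M(), (torch.randn(4, 8),), dynamic_shapes=([batch, 8],))
```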
@@ -1002,14 +1012,14 @@ def _get_dim_name_mapping(
     name_to_dim = {}
     for dim in tree_flatten(
         dynamic_shapes,
-        is_leaf=lambda x: isinstance(x, _Dim),
+        is_leaf=lambda x: isinstance(x, Dim),
     )[0]:
         if dim is None:
             # NOTE: this must denote a non-Tensor or automatic at this point.
             continue
         if isinstance(dim, int):
             continue
-        elif isinstance(dim, _Dim):
+        elif isinstance(dim, Dim):
             name_to_dim[dim.__name__] = dim
             if isinstance(dim, _DerivedDim):
                 name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
@@ -1092,7 +1102,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
     # track derived dim roots
     roots: set[str] = set()
     for k, c in shape_fixes.items():
-        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
+        assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr))
         if isinstance(c, sympy.Expr):  # check dim/derived dim expression
             assert _is_supported_equivalence(c)
             shape_fixes[k] = c
@@ -2784,8 +2784,8 @@ class DimConstraints:
     ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]:
         return isinstance(dim, torch.export.dynamic_shapes._DerivedDim)

-    def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes._Dim]:
-        return isinstance(dim, torch.export.dynamic_shapes._Dim) and not isinstance(
+    def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]:
+        return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance(
             dim, torch.export.dynamic_shapes._DerivedDim
         )

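The narrowed `TypeGuard` above distinguishes plain dims from derived dims. A tiny standalone illustration of the same predicate (the helper below is hypothetical, not part of `DimConstraints`):

```python
from torch.export.dynamic_shapes import _DerivedDim, Dim

dx = Dim("dx", min=2, max=32)
dy = dx + 1  # a _DerivedDim


def is_plain_dim(d: object) -> bool:
    # Same condition as DimConstraints._is_dim: a Dim that is not derived.
    return isinstance(d, Dim) and not isinstance(d, _DerivedDim)


print(is_plain_dim(dx))  # True
print(is_plain_dim(dy))  # False
```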
@@ -8,7 +8,7 @@ import warnings
 from typing import Any, TYPE_CHECKING

 import torch
-from torch.export.dynamic_shapes import _Dim, _DimHint
+from torch.export.dynamic_shapes import _DimHint, Dim
 from torch.onnx._internal._lazy_import import onnxscript_ir as ir
 from torch.utils import _pytree

@@ -157,7 +157,7 @@ def from_dynamic_shapes_to_dynamic_axes(
 def _any_str_or_dim_in_dynamic_shapes(
     dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any],
 ) -> bool:
-    """Check if there is any string or _Dim in the dynamic_shapes."""
+    """Check if there is any string or Dim in the dynamic_shapes."""
     flat_dynamic_shapes, _ = _flatten_dynamic_shapes_to_axes(dynamic_shapes)
     # This indicates the dynamic_shapes includes something we don't support in axes, and it's flattened
     # to itself. Otherwise, flat_dynamic_shapes should be a list of dict/list/tuple (or None).
@@ -166,15 +166,15 @@ def _any_str_or_dim_in_dynamic_shapes(
         for axes in flat_dynamic_shapes
     ):
         return False
-    # both str and _Dim can provide custom names
+    # both str and Dim can provide custom names
     for axes in flat_dynamic_shapes:
         if isinstance(axes, dict):
             for dim in axes.values():
-                if isinstance(dim, (str, _Dim)):
+                if isinstance(dim, (str, Dim)):
                     return True
         elif isinstance(axes, (list, tuple)):
             for dim in axes:
-                if isinstance(dim, (str, _Dim)):
+                if isinstance(dim, (str, Dim)):
                     return True
     return False

@@ -190,7 +190,7 @@ def convert_str_to_export_dim(
     # for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
     # to {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
     dynamic_shapes_with_export_dim: list[
-        list[_Dim | _DimHint | None] | dict[int, _Dim | _DimHint | None] | None
+        list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None
     ] = []
     flat_dynamic_shapes, tree_structure = _flatten_dynamic_shapes_to_axes(
         dynamic_shapes
@@ -199,7 +199,7 @@ def convert_str_to_export_dim(
         if axes is None:
             dynamic_shapes_with_export_dim.append(None)
         elif isinstance(axes, dict):
-            converted_axes_dict: dict[int, _Dim | _DimHint | None] = {}
+            converted_axes_dict: dict[int, Dim | _DimHint | None] = {}
             for axis, dim in axes.items():
                 if isinstance(dim, str):
                     converted_axes_dict[axis] = torch.export.Dim.AUTO
@@ -207,7 +207,7 @@ def convert_str_to_export_dim(
                     converted_axes_dict[axis] = dim
             dynamic_shapes_with_export_dim.append(converted_axes_dict)
         elif isinstance(axes, (list, tuple)):
-            converted_axes_list: list[_Dim | _DimHint | None] = []
+            converted_axes_list: list[Dim | _DimHint | None] = []
             for dim in axes:
                 if isinstance(dim, str):
                     converted_axes_list.append(torch.export.Dim.AUTO)
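These annotation-only hunks sit inside the string-to-`Dim.AUTO` conversion described by the comment in the first of them: axes named by strings become `torch.export.Dim.AUTO`, while existing `Dim`/`_DimHint`/`None` entries pass through. A standalone sketch of that per-axes transformation; the helper name and spec values below are illustrative, not the module's API:

```python
import torch


def _convert_axes(axes):
    # Mirrors the dict and list/tuple branches of convert_str_to_export_dim:
    # strings become Dim.AUTO, everything else is kept as-is.
    if axes is None:
        return None
    if isinstance(axes, dict):
        return {
            axis: torch.export.Dim.AUTO if isinstance(dim, str) else dim
            for axis, dim in axes.items()
        }
    return [torch.export.Dim.AUTO if isinstance(dim, str) else dim for dim in axes]


# {"y": {0: "dim_0"}, "x": {1: "dim_1"}}  ->  {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
spec = {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
print({name: _convert_axes(axes) for name, axes in spec.items()})
```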
@@ -292,9 +292,9 @@ def create_rename_mapping(
     return rename_mapping


-def _get_custom_axis_name(axis: _Dim | str) -> str:
+def _get_custom_axis_name(axis: Dim | str) -> str:
     """Get the custom axis name from a torch.export.Dim."""
-    if isinstance(axis, _Dim):
+    if isinstance(axis, Dim):
         return axis.__name__
     return axis

@@ -310,18 +310,18 @@ def _unflatten_dynamic_shapes_with_inputs_tree(
 def _flatten_dynamic_shapes_to_axes(
     dynamic_shapes: dict[str, Any | None] | tuple[Any, ...] | list[Any],
 ) -> tuple[list[Any], _pytree.TreeSpec]:
-    # If it's a dict/list/tuple with torch.export._Dim, we consider it's an axis to dim mapping
+    # If it's a dict/list/tuple with torch.export.Dim, we consider it's an axis to dim mapping
     def is_axes(x) -> bool:
         return (
             isinstance(x, dict)
             and all(
                 isinstance(k, int)
-                and (v is None or isinstance(v, (_Dim, _DimHint, str, int)))
+                and (v is None or isinstance(v, (Dim, _DimHint, str, int)))
                 for k, v in x.items()
             )
         ) or (
             isinstance(x, (list, tuple))
-            and all(v is None or isinstance(v, (_Dim, _DimHint, str, int)) for v in x)
+            and all(v is None or isinstance(v, (Dim, _DimHint, str, int)) for v in x)
         )

     return _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes)
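The `is_axes` predicate above stops pytree flattening at anything that looks like an axis-to-dim mapping, so each per-input spec comes back as a single leaf. A small standalone illustration of the same pattern with `torch.utils._pytree`; the predicate is simplified and the spec values are illustrative:

```python
from torch.export import Dim
from torch.utils import _pytree


def is_axes(x) -> bool:
    # Simplified version of the leaf predicate in _flatten_dynamic_shapes_to_axes:
    # {axis_index: Dim/str/int/None} dicts and per-axis lists count as leaves.
    if isinstance(x, dict):
        return all(
            isinstance(k, int) and (v is None or isinstance(v, (Dim, str, int)))
            for k, v in x.items()
        )
    if isinstance(x, (list, tuple)):
        return all(v is None or isinstance(v, (Dim, str, int)) for v in x)
    return False


batch = Dim("batch", min=2, max=64)
dynamic_shapes = {"x": {0: batch, 1: "height"}, "y": [batch, None]}

# Flattening returns two leaves: the axes dict for "x" and the axes list for "y".
leaves, spec = _pytree.tree_flatten(dynamic_shapes, is_leaf=is_axes)
print(len(leaves), leaves)
```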