pytorch/torch/_export/serde/dynamic_shapes.py
Commit 4ab847bbc7 by Maggie Moss: Pyrefly suppressions 4/n (#164615)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: uncomment lines in the pyrefly.toml file
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
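Concretely, step 3 means adding inline `# pyrefly: ignore` comments above the flagged lines; a minimal sketch (the error code shown is one that appears later in this file, the real codes come from the `pyrefly check` output):

```python
# pyrefly: ignore  # bad-argument-type
spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
```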
before: https://gist.github.com/maggiemoss/356645cf8cfe33123d9a27f23b30f7b1

after:

0 errors (2,753 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164615
Approved by: https://github.com/oulgen
2025-10-06 16:14:36 +00:00


import dataclasses
from typing import Any, Optional, Union

import torch
from torch._dynamo.exc import UserError, UserErrorType
from torch.export.dynamic_shapes import (
    _check_dynamic_shapes,
    _DerivedDim,
    _DimHint,
    _tree_map_with_path,
    Dim,
)
from torch.utils._pytree import tree_map

from .serialize import _dataclass_to_dict


@dataclasses.dataclass
class RootDim:
    """
    This represents a Dim object.
    """

    min: int
    max: Union[int, None]
    derived: list[str]


@dataclasses.dataclass
class DynamicShapesSpec:
    """
    This stores a dynamic_shapes spec for de/serialization.
    """

    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None]
    dims: dict[str, RootDim]


def _postprocess_serialized_shapes(
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
    dims: dict[str, dict[str, Union[int, list[str], None]]],
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, dict[str, Any]]:
    """
    Sorts dims and dumps to dictionary format.
    """
    from torch.utils._sympy.numbers import int_oo

    dims = {
        k: RootDim(
            min=v["min"],  # type: ignore[arg-type]
            max=None if v["max"] is int_oo else v["max"],  # type: ignore[arg-type]
            derived=sorted(v["derived"]),  # type: ignore[arg-type]
        )
        for k, v in sorted(dims.items())
    }
    # pyrefly: ignore  # bad-argument-type
    spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
    if to_dict:
        return _dataclass_to_dict(spec)
    else:
        return spec


def _dump_dynamic_shapes(
    dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None],
    args: tuple[Any],
    kwargs: Optional[dict[str, Any]] = None,
    to_dict: Optional[bool] = False,
) -> Union[DynamicShapesSpec, dict[str, Any]]:
    """
    Utility function for dynamic shapes serialization, serializing a dynamic_shapes spec.
    Returns a DynamicShapesSpec dataclass containing 2 fields, "dynamic_shapes" and "dims".
    Uses args & kwargs to distinguish between tensor-level and dim-level specs (only for Nones).

    dynamic_shapes: A pytree structure mirroring the dynamic_shapes input to export():
        - Each tensor input is represented with a list of values, non-tensor inputs with None.
        - dynamic dimensions (i.e. symbols) in tensors and Dim enums are represented with strings.
        - static dimensions are represented with ints.

    dims: A dictionary mapping each symbol name to the min/max range and derived dim names.

    For example:
    ```
    dx = Dim("dx", min=4, max=16)
    dy = dx + 1

    inputs = (
        [
            torch.randn(4, 4),
            torch.randn(5, 4),
        ],
        torch.randn(4),
        torch.randn(4, 4),
        "hello",
    )
    dynamic_shapes = {
        "a": [
            (dx, 4),
            (dy, 4),
        ],
        "b": (Dim.STATIC,),
        "c": None,
        "d": None,
    }
    out = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
    ```
    would generate the following output:
    ```
    {
        "dynamic_shapes": (
            [
                ["dx", 4],
                ["dx + 1", 4],
            ],
            ["_DimHint.STATIC"],
            ["_DimHint.STATIC", "_DimHint.STATIC"],
            None,
        ),
        "dims": {
            "dx": {
                "min": 4,
                "max": 16,
                "derived": ["dx + 1"],
            },
        },
    }
    ```
    """
    dims: dict[str, dict[str, Any]] = {}

    def _standardize_shapes(path, tensor, shape):  # type: ignore[no-untyped-def]
        """
        Helps standardize the dynamic_shapes tree structure we serialize,
        returning lists for each tensor shape, handling tensor-level Nones.
        """
        if not isinstance(tensor, torch.Tensor):
            return None
        if shape is None:
            return [Dim.STATIC] * len(tensor.shape)

        out = []
        if isinstance(shape, dict):
            for i, s in enumerate(tensor.shape):
                out.append(s if shape.get(i) is None else shape.get(i))
        else:
            assert isinstance(shape, (tuple, list))
            for i, s in enumerate(tensor.shape):
                out.append(s if shape[i] is None else shape[i])
        return out

    def _track_dim_from_dims(
        val: Union[None, int, _DimHint, Dim],
    ) -> Union[None, int, str]:
        """
        Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
        """
        if val is None or isinstance(val, int):  # non-tensor input or static
            return val
        if isinstance(val, _DimHint):  # store enum as string
            return val.__class__.__name__ + "." + val.type.name

        assert isinstance(val, Dim)

        # track root dim
        root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
        if root.__name__ not in dims:
            dims[root.__name__] = {
                "min": root.min,  # type: ignore[attr-defined, union-attr]
                "max": root.max,  # type: ignore[attr-defined, union-attr]
                "derived": set(),
            }

        # track derived dims
        if isinstance(val, _DerivedDim):
            dims[root.__name__]["derived"].add(val.__name__)

        return val.__name__

    if dynamic_shapes is None:
        return {"dynamic_shapes": None, "dims": {}}

    # convert to tuple of specs, for each arg/kwarg
    kwargs = kwargs or {}
    if isinstance(dynamic_shapes, dict):
        dynamic_shapes = dynamic_shapes.values()  # type: ignore[assignment]
    # pyrefly: ignore  # bad-assignment, bad-argument-type
    dynamic_shapes = tuple(dynamic_shapes)
    combined_args = tuple(args) + tuple(kwargs.values())

    # run same check when we're processing shapes for export - is this too lazy?
    _check_dynamic_shapes(dict(enumerate(combined_args)), dynamic_shapes)  # type: ignore[arg-type]

    tree_shapes = _tree_map_with_path(
        _standardize_shapes, combined_args, dynamic_shapes, tree_name="inputs"
    )
    serialized_shapes = tree_map(_track_dim_from_dims, tree_shapes)
    return _postprocess_serialized_shapes(serialized_shapes, dims, to_dict=to_dict)


def _load_dynamic_shapes(
    spec: Union[DynamicShapesSpec, dict[str, Any]],
    from_dict: Optional[bool] = False,
) -> Union[dict[str, Any], tuple[Any], list[Any], None]:
    """
    Utility function for dynamic shapes serialization.
    Deserializes a DynamicShapesSpec or corresponding dictionary into a dynamic_shapes input to export().
    """
    import sympy

    from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence

    if from_dict:
        if not isinstance(spec, dict):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                f"With from_dict=True, expected `spec` to be a dict, got {type(spec)}",
            )
        if sorted(spec.keys()) != ["dims", "dynamic_shapes"]:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "With from_dict=True, expected `spec` to have keys `dims` and `dynamic_shapes`, "
                f"instead found {spec.keys()}",
            )

        dims = {}
        for k, v in spec["dims"].items():
            if not isinstance(k, str):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` keys to be strings for symbols, got key {type(k)}",
                )
            if sorted(v.keys()) != ["derived", "max", "min"]:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected `spec['dims']` values to have keys `derived`, `max`, and `min`, "
                    f"instead found {v.keys()}",
                )
            if not isinstance(v["min"], int):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected dims in `spec['dims']` to map `min` to an int, got {k}: {v['min']}",
                )
if not isinstance(v["max"], int) or v["max"] is None:
raise UserError(
UserErrorType.INVALID_INPUT,
f"Expected dims in `spec['dims']` to map `max` to an int or None, got {k}: {v['max']}",
)
if not isinstance(v["derived"], list) or any(
not isinstance(d, str) for d in v["derived"]
):
raise UserError(
UserErrorType.INVALID_INPUT,
"Expected dims in `spec['dims']` to map `derived` to a list of derived expressions, "
f"got {k}: {v['derived']}",
)
dims[k] = RootDim(**v)
dynamic_shapes = spec["dynamic_shapes"]
else:
if not isinstance(spec, DynamicShapesSpec):
raise UserError(
UserErrorType.INVALID_INPUT,
f"Expected `spec` to be a DynamicShapesSpec, got {type(spec)}",
)
dims = spec.dims
dynamic_shapes = spec.dynamic_shapes
if dynamic_shapes is None:
return None
dim_cache = {}
for name, info in dims.items():
symbol = sympy.sympify(name)
if not isinstance(symbol, sympy.Symbol):
raise UserError(
UserErrorType.INVALID_INPUT,
f"Expected `spec['dims']` keys to be symbols, got {name}",
)
dim_cache[name] = Dim(name, min=info.min, max=info.max) # cache root dim
for _expr in info.derived:
expr = sympy.sympify(_expr)
if len(expr.free_symbols) != 1 or symbol not in expr.free_symbols:
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions to have {name} as the only free symbol, got {expr}",
                )
            if not _is_supported_equivalence(expr):
                raise UserError(
                    UserErrorType.INVALID_INPUT,
                    f"Expected derived expressions to be linear expressions, got {expr}",
                )
            modulus, remainder = sympy.polys.polytools.div(expr, symbol)
            ddim = dim_cache[name]
            if modulus != 1:
                ddim = int(modulus) * ddim  # type: ignore[assignment, operator]
            if remainder != 0:
                ddim = ddim + int(remainder)  # type: ignore[assignment, operator]
            dim_cache[_expr] = ddim  # cache derived dims

    def deserialize_shape(
        val: Union[None, int, str],
    ) -> Union[None, int, Dim, _DimHint]:
        if val is None or isinstance(val, int):
            return val
        elif val == "_DimHint.AUTO":
            return _DimHint.AUTO()
        elif val == "_DimHint.DYNAMIC":
            return _DimHint.DYNAMIC()
        elif val == "_DimHint.STATIC":
            return _DimHint.STATIC()
        if not isinstance(val, str):
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected leaves in `spec['dynamic_shapes']` to be ints, None, Dim.AUTO/STATIC, symbols, "
                f" or derived expressions, got {val}",
            )
        if val not in dim_cache:
            raise UserError(
                UserErrorType.INVALID_INPUT,
                "Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
                f"got {val} which is not in {dims.keys()}",
            )
        return dim_cache[val]  # type: ignore[return-value]

    return tree_map(deserialize_shape, dynamic_shapes)
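
A minimal round-trip sketch of how the two entry points above are intended to be used together, based on the example in the `_dump_dynamic_shapes` docstring (the input shapes and dim names here are illustrative only):

```python
import torch
from torch.export.dynamic_shapes import Dim
from torch._export.serde.dynamic_shapes import (
    _dump_dynamic_shapes,
    _load_dynamic_shapes,
)

dx = Dim("dx", min=4, max=16)
inputs = (torch.randn(4, 4),)
dynamic_shapes = {"x": (dx, None)}  # dim 0 is dynamic, dim 1 stays static

# Serialize to a plain (JSON-friendly) dict, then rebuild an equivalent spec.
serialized = _dump_dynamic_shapes(dynamic_shapes, inputs, to_dict=True)
restored = _load_dynamic_shapes(serialized, from_dict=True)
# `restored` mirrors the original pytree, with "dx" reconstructed as Dim("dx", min=4, max=16).
```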