pytorch/functorch/dim/_order.py
Edward Yang 97eb7a281d torchdim Python port (#160236)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new dispatch algorithm is that we first land in `__torch_function__`, check whether a special FCD (first-class dim) implementation needs to be dispatched to first, and if there is none we fall back to the standard level strategy.
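
Schematically, the new dispatch looks something like the sketch below (a rough illustration, not the actual code in this PR; the `_FCD_OVERRIDES` registry is a hypothetical name, and the fallback is simplified to delegating to the default behavior):

```python
import torch

_FCD_OVERRIDES: dict = {}  # hypothetical registry: torch op -> FCD implementation


class Tensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # 1. If a first-class-dim (FCD) implementation is registered for this
        #    op, dispatch to it first.
        special = _FCD_OVERRIDES.get(func)
        if special is not None:
            return special(*args, **kwargs)
        # 2. Otherwise fall back to the standard level strategy (sketched here
        #    as plain delegation to torch.Tensor's default handling).
        return super().__torch_function__(func, types, args, kwargs)
```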

Because there are no longer C binding equivalents of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs, as the Python API sometimes differs from the C API. I've attempted to disambiguate these, but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful: Dim must abide by Tensor equality semantics, but uses pointer equality in C (DimEntry doesn't have this problem). Another subtle C/Python difference is that we no longer get implicit conversions from Dim to DimEntry, which also caused some bugs.
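
A rough illustration of the equality pitfall (a sketch of the semantics, not code from this PR; the `functorch.dim._dim_entry` import path is assumed):

```python
from functorch.dim import dims
from functorch.dim._dim_entry import DimEntry  # internal module, path assumed

i, j = dims(2)

# Dim has Tensor equality semantics, so `==` is not the pointer comparison the
# C++ code relied on; identity checks have to be spelled out explicitly.
assert i is i and i is not j

# DimEntry is a plain value type, so ordinary equality works, but Dims are no
# longer implicitly converted to DimEntry and must be wrapped explicitly.
assert DimEntry(i) == DimEntry(i)
assert DimEntry(i) != DimEntry(j)
```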

Much of the mechanical porting work was done by Claude Code. I have a separate PR that deletes functorch._C, but it was useful to have dim.cpp to point Claude at, so I haven't done that in this PR. From a reviewing perspective, I need to re-check that I didn't forget to port anything; there are some noticeably missing "small" things, such as patched_dim_method. I am still in the process of carefully doing a side-by-side review of the ports; "simplifications" from Claude Code were also a major source of bugs.

There are two major feature gaps in the implementation:

- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy; it just needs to be done. However, for the purposes of sharding propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn it into a tensor, instead of it being eligible for `__torch_function__`. I think I might need to hard-code a special case for this or something; see the sketch below.
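
For reference, the affected pattern looks roughly like this (illustrative only; `dims(sizes=...)` as in the existing functorch.dim API):

```python
import torch
from functorch.dim import dims

i, j = dims(sizes=[3, 2])
x = torch.randn(6, 4)

# Intended torchdim behavior: split positional dim 0 (size 6) into i and j.
y = x[[i, j], :]
# In this port, Tensor.__getitem__ sees a plain Python list, treats it as an
# advanced-indexing index, and converts it with torch.tensor(...) before
# __torch_function__ ever gets to see the Dim objects, so the split fails.
```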

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
2025-09-21 03:01:04 +00:00


from __future__ import annotations

from typing import Any, TYPE_CHECKING, Union

if TYPE_CHECKING:
    from collections.abc import Sequence

import torch  # noqa: TC002

from ._dim_entry import _match_levels, DimEntry, ndim_of_levels


def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
    """
    Convert various dimension representations to DimEntry.

    Args:
        arg: The argument to convert (Dim, int, or other)
        orig_ndim: Original number of dimensions
        allow_none: Whether to allow None values

    Returns:
        DimEntry representation of the dimension
    """
    from . import Dim

    if arg is None and allow_none:
        return DimEntry()  # None entry
    elif isinstance(arg, Dim):
        return DimEntry(arg)
    elif isinstance(arg, int):
        if arg < 0:
            pos = arg
        else:
            pos = arg - orig_ndim
        return DimEntry(pos)
    else:
        return DimEntry()
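
# For illustration (not executed): with orig_ndim=4, _wrap_dim maps
#   _wrap_dim(1, 4)     -> DimEntry(-3)   (positional index counted from the right)
#   _wrap_dim(-1, 4)    -> DimEntry(-1)
#   _wrap_dim(d, 4)     -> DimEntry(d)    (d a first-class Dim)
#   _wrap_dim(None, 4)  -> DimEntry()     (a "none" entry, when allow_none=True)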


def order(
    tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
) -> torch.Tensor:
    """
    Reorder the dimensions of a tensor or create a tensor from a dimension.

    It allows reordering tensor dimensions using first-class dimensions and
    positional indices.

    Args:
        tensor_or_dim: Input tensor with first-class dimensions, or a Dim object
        *dims: Dimensions or sequences of dimensions specifying the new order

    Returns:
        Tensor with reordered dimensions

    Examples:
        >>> import torch
        >>> from functorch.dim import dims
        >>> batch, channel, height, width = dims(4)
        >>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
        >>> # Reorder to [height, width, batch, channel]
        >>> y = order(x, height, width, batch, channel)
    """
    from . import Dim, DimList, Tensor

    # Handle first argument - tensor or dimension
    if isinstance(tensor_or_dim, Tensor):
        # First-class tensor
        orig_levels = tensor_or_dim._levels[:]
        data = tensor_or_dim._tensor
        has_device = tensor_or_dim._has_device
    elif isinstance(tensor_or_dim, Dim):
        # Single dimension - create range tensor
        orig_levels = [DimEntry(tensor_or_dim)]
        data = tensor_or_dim._get_range()
        has_device = False
    else:
        raise ValueError("First argument must be a Tensor or Dim object")

    flat_positional_dims = []
    to_flatten = []  # List of (start_index, length) pairs for flattening
    levels = orig_levels[:]
    orig_ndim = ndim_of_levels(levels)

    def append_dim(d: DimEntry) -> None:
        """Add a dimension to the reordering, removing it from available levels."""
        try:
            idx = levels.index(d)
        except ValueError:
            idx = None
        if idx is None:
            if d.is_positional():
                raise ValueError(
                    f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
                    f"or it was specified twice"
                )
            else:
                raise ValueError(
                    f"tensor does not contain dim {d.dim()} or it was specified twice"
                )
        levels[idx] = DimEntry()
        flat_positional_dims.append(d)

    n_new_positional = 0

    # Process each dimension argument
    for arg in dims:
        entry = _wrap_dim(arg, orig_ndim, False)
        if not entry.is_none():
            append_dim(entry)
            n_new_positional += 1
        elif isinstance(arg, DimList):
            # Handle DimList
            for dim in arg._dims:
                append_dim(DimEntry(dim))
                n_new_positional += 1
        else:
            # Handle sequences of dimensions for flattening
            n_new_positional += 1
            if not hasattr(arg, "__iter__"):
                raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
            # Convert to list to get length
            seq = list(arg)
            to_flatten.append((len(flat_positional_dims), len(seq)))
            for item in seq:
                entry = _wrap_dim(item, orig_ndim, False)
                if entry.is_none():
                    raise ValueError("expected a Dim or int")
                append_dim(entry)

    # Build new level ordering
    insert_point = -1
    new_levels: list[DimEntry] = []

    # Add remaining (non-reordered) levels, finding insertion point for new dimensions
    for level in levels:
        if level.is_none():
            continue
        if level.is_positional():
            if insert_point == -1:
                insert_point = len(new_levels)
                new_levels.extend(flat_positional_dims)
        new_levels.append(level)

    # If no positional dimensions found, append new dims at the end
    if insert_point == -1:
        insert_point = len(new_levels)
        new_levels.extend(flat_positional_dims)

    # Match tensor to new level structure
    assert data is not None, "Cannot reorder None tensor"
    ndata = _match_levels(data, orig_levels, new_levels)

    # Handle dimension flattening if requested
    if to_flatten:
        # Now build the reshape target
        view_shape = []
        sizes = ndata.size()

        # Add dimensions before the reordered ones
        for i in range(insert_point):
            view_shape.append(sizes[i])

        # Process flattening groups
        i = 0
        for start_idx, length in to_flatten:
            # Add individual dims before this flattening group
            while i < start_idx:
                view_shape.append(sizes[insert_point + i])
                i += 1
            # Flatten the group
            new_size = 1
            for j in range(length):
                new_size *= sizes[insert_point + i + j]
            view_shape.append(new_size)
            i += length

        # Add remaining individual dims
        while i < len(flat_positional_dims):
            view_shape.append(sizes[insert_point + i])
            i += 1

        # Add dimensions after the reordered ones
        for i in range(insert_point + len(flat_positional_dims), len(levels)):
            view_shape.append(sizes[i])

        # Update levels by removing flattened dimensions
        n_to_remove = len(flat_positional_dims) - n_new_positional
        if n_to_remove > 0:
            # Remove flattened levels
            new_levels = (
                new_levels[:insert_point] + new_levels[insert_point + n_to_remove :]
            )

        ndata = ndata.reshape(view_shape)

    # Renumber positional dimensions (negative indexing from the right)
    seen = 0
    for i in range(len(new_levels) - 1, -1, -1):
        if new_levels[i].is_positional() or (
            i >= insert_point and i < insert_point + n_new_positional
        ):
            seen -= 1
            new_levels[i] = DimEntry(seen)

    result = Tensor.from_positional(ndata, new_levels, has_device)
    return result  # type: ignore[return-value]
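
# Illustrative usage of the flattening path (not executed here): given
#   batch, height, width = dims(3)
#   x = torch.randn(2, 4, 5)[batch, height, width]
# then
#   order(x, batch, (height, width))
# yields a tensor of shape (2, 20): `batch` becomes the leading positional dim
# and the (height, width) group is flattened into a single new positional dim.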