pytorch/functorch/dim/_getsetitem.py
Edward Yang 97eb7a281d torchdim Python port (#160236)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__`, check whether a special FCD (first-class dimension) implementation needs to be dispatched to first, and if there is none we fall back to the standard level strategy.
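
To make the dispatch order concrete, here is a minimal sketch of the strategy (FCD_IMPLS and _level_strategy are illustrative stand-ins, not the real names in the port):

    class Tensor(torch.Tensor):
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # First: is there a special FCD implementation for this op?
            special = FCD_IMPLS.get(func)
            if special is not None:
                return special(cls, func, types, args, kwargs)
            # Otherwise fall back to the standard level strategy.
            return _level_strategy(func, types, args, kwargs)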

Because there is no longer a C binding equivalent of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs, as the Python API sometimes differs from the C API. I've attempted to disambiguate these, but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful, as Dim must abide by Tensor equality semantics in Python but uses pointer equality in C (DimEntry doesn't have this problem). Another subtle C/Python difference is that we no longer get implicit conversions from Dim to DimEntry; this also caused some bugs.
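
For example, because Dim follows Tensor equality semantics, `==` on a Dim produces a tensor rather than a plain bool, so anything that relies on __eq__ for container lookups (`in`, list.index) misbehaves; lookups have to use identity instead (this is what _safe_index in _getsetitem.py below does):

    d1 == d2                           # a Tensor, not a bool
    d1 in [d2, d3]                     # implicitly calls bool(tensor): wrong
    any(x is d1 for x in seen_dims)    # identity check: correct for Dim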

Much of the mechanical porting work was done by Claude Code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point Claude at, so I haven't done that in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything; one noticeably missing "small" thing is patched_dim_method. I am still in the process of carefully doing a side-by-side review of the ports; "simplifications" from Claude Code were also a major source of bugs.

There are two major feature gaps in the implementation:

- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy; we just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn it into a tensor, instead of letting it be eligible for `__torch_function__`. I think I might need to hard-code a special case for this or something? See the sketch below.
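
For reference, a sketch of the indexing pattern that currently falls over (sizes are illustrative):

    import torch
    from functorch.dim import dims

    x, y = dims(2)
    x.size = 2          # at most one dim in a pack may be unbound
    t = torch.randn(6)
    t2 = t[[x, y]]      # intended: split dim 0 into x (2) and y (3), but
                        # CPython routes the list through torch.tensor(...)
                        # before __torch_function__ ever sees the Dims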

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
2025-09-21 03:01:04 +00:00


from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
    from . import Dim


def _safe_index(lst: list, item: Any) -> Optional[int]:
    """
    Helper function to find index of item in list.

    For DimEntry objects, uses __eq__ comparison which properly handles
    both positional and Dim entries.

    Returns the index if found, None if not found.
    """
    for i, list_item in enumerate(lst):
        # Use == for DimEntry objects as they have proper __eq__ implementation
        if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
            if list_item == item:
                return i
        elif list_item is item:
            return i
    return None
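

# Illustration of why _safe_index exists (hypothetical values): `in` and
# list.index call __eq__, and Dim follows Tensor equality semantics, so
# `d1 == d2` yields a tensor rather than a bool. Identity is used for Dims,
# while DimEntry has an ordinary __eq__:
#
#     _safe_index([d1, d2], d2)                  -> 1    (identity match)
#     _safe_index([DimEntry(-1)], DimEntry(-1))  -> 0    (__eq__ match)
#     _safe_index([d1, d2], d3)                  -> None (not found)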


@dataclass
class IndexingInfo:
    can_call_original: bool = False
    advanced_indexing: bool = False
    self_tensor: Optional[torch.Tensor] = None
    flat_inputs: list[Any] = field(default_factory=list)
    result_levels: list[DimEntry] = field(default_factory=list)
    has_device: bool = False


def has_dims(obj: Any) -> bool:
    """
    Check if an object has first-class dimensions.

    This function checks if the object is either a Dim or a functorch Tensor
    that has first-class dimensions, using the proper check_exact methods.
    """
    from . import Dim, Tensor

    return Dim.check_exact(obj) or Tensor.check_exact(obj)


def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
    """
    Bind dimensions to size and calculate proper strides for dim packs.
    """
    from . import DimensionBindError

    rhs_prod = 1
    for i, dim in enumerate(dims):
        if not dim.is_bound:
            # Check for multiple unbound dimensions
            for j in range(i + 1, len(dims)):
                if not dims[j].is_bound:
                    raise DimensionBindError(
                        f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
                    )
                rhs_prod *= dims[j].size
            # Calculate the size for this unbound dimension
            if sz % rhs_prod != 0:
                tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
                raise DimensionBindError(
                    f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
                )
            inferred_size = sz // rhs_prod
            dim.size = inferred_size
            rhs_prod = sz
            break
        else:
            rhs_prod *= dim.size
    # Final validation that dimensions match
    if rhs_prod != sz:
        tup = tuple(dims)
        raise DimensionBindError(
            f"Dimension sizes do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
        )
    # Calculate new sizes and strides for each dimension in the pack
    # First calculate all strides by iterating in reverse
    new_strides = [0] * len(dims)
    current_stride = sd
    for i in reversed(range(len(dims))):
        new_strides[i] = current_stride
        current_stride *= dims[i].size
    # Then append sizes and strides in forward order
    for i in range(len(dims)):
        nsz.append(dims[i].size)
        nsd.append(new_strides[i])
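

# Worked example for _bind_dims_to_size (hypothetical sizes): binding a dim
# pack (x, y) with y.size == 3 against a dimension of size sz=6, stride sd=1
# infers x.size = 6 // 3 = 2; the reverse stride pass then yields
# nsz = [2, 3] and nsd = [3, 1] (y steps by 1, x steps by y.size * 1 = 3).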


def slice_to_tuple(flat_inputs: list) -> tuple:
    return tuple(flat_inputs)


def extractIndices(index: Any, indices: list) -> bool:
    if isinstance(index, tuple):  # mpy::tuple_view::check
        indices.extend(index)
        return True
    elif isinstance(index, torch.Tensor):  # THPVariable_Check
        indices.append(index)
        return False
    elif not hasattr(index, "__iter__") or isinstance(
        index, (str, bytes)
    ):  # !mpy::is_sequence
        indices.append(index)
        return False
    # Handle sequence case (list)
    if isinstance(index, list):
        if len(index) >= 32:
            indices.extend(index)
            return True
        # Check each item in the sequence
        for item in index:
            if (
                isinstance(item, (torch.Tensor, slice))
                or hasattr(item, "__iter__")
                or item is ...
                or item is None
                or has_dims(item)
            ):
                indices.extend(index)
                return True
        # If we got here, treat as single index
        indices.append(index)
        return False
    # Default case
    indices.append(index)
    return False
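

# Classification examples for extractIndices (illustrative); the return value
# reports whether `index` was splatted as a sequence of indices:
#
#     extractIndices((i, j), out)            -> True   (tuple: always splatted)
#     extractIndices(tensor, out)            -> False  (single tensor index)
#     extractIndices([0, slice(None)], out)  -> True   (list holding a slice)
#     extractIndices([0, 1, 2], out)         -> False  (plain int list is one
#                                                       advanced index)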


def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
    self = args[0]
    index = args[1]
    iinfo = getsetitem(self, index, has_dims(self))
    if iinfo.can_call_original:
        # Call original tensor __getitem__ directly, bypassing __torch_function__
        return torch.Tensor.__getitem__(self, index)
    return invoke_getitem(iinfo)


def setitem(self: Any, index: Any, rhs: Any) -> None:
    """Set values in tensor using first-class dimensions."""
    from . import DimensionBindError, TensorInfo

    iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
    if iinfo.can_call_original:
        # Call original tensor __setitem__ directly, bypassing __torch_function__
        torch._C.TensorBase.__setitem__(self, index, rhs)
        return
    # Handle RHS tensor with dimensions
    rhs_info = TensorInfo.create(rhs, False, False)
    if rhs_info:
        # Check that rhs dimensions are compatible with result dimensions
        for l in rhs_info.levels:
            if not l.is_positional():
                # Find this dimension in result levels
                found = False
                for result_level in iinfo.result_levels:
                    if (
                        not result_level.is_positional()
                        and result_level.dim() is l.dim()
                    ):
                        found = True
                        break
                if not found:
                    # Create tuple representation of result levels for error message
                    result_dims: list[Union[int, Dim]] = []
                    for rl in iinfo.result_levels:
                        if rl.is_positional():
                            result_dims.append(rl.position())
                        else:
                            result_dims.append(rl.dim())
                    raise DimensionBindError(
                        f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimensions on the left "
                        f"({tuple(result_dims)!r})"
                    )
        # Match RHS tensor to result levels
        assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
        matched_rhs = _match_levels(
            rhs_info.tensor, rhs_info.levels, iinfo.result_levels
        )
    else:
        matched_rhs = rhs
    # For advanced indexing with dimensions, we need special handling
    if iinfo.advanced_indexing:
        # Use advanced indexing - the flat_inputs already contain matched tensors
        tup = slice_to_tuple(iinfo.flat_inputs)
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot setitem on None tensor")
        torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
    else:
        # Simple copy operation
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot copy to None tensor")
        iinfo.self_tensor.copy_(matched_rhs)


def invoke_getitem(iinfo: IndexingInfo) -> Any:
    if iinfo.advanced_indexing:
        self_tensor = iinfo.self_tensor
        tup = slice_to_tuple(iinfo.flat_inputs)
        if self_tensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        rtensor = self_tensor[tup]
    else:
        rtensor = iinfo.self_tensor  # type: ignore[assignment]
        if rtensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
    # rtensor is now guaranteed to be not None
    # Create a Tensor with the proper dimensions using the class method
    from . import Tensor

    return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)


def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
    from . import DimList  # Import DimList for type checking

    can_call_original_getitem = not tensors_have_dims
    input_list = []
    if has_dims(index):
        input_list.append(index)
    else:
        is_sequence = extractIndices(index, input_list)
        # nothing about first class dims here, fall back to getitem
        if can_call_original_getitem and not is_sequence:
            return IndexingInfo(can_call_original=True)
    # Calculate how many dimensions have been indexed in order to compute the
    # size of ... or expand a potentially unbound dimension list.
    dims_indexed = 0
    expanding_object = -1
    unbound_dim_list = None
    dimlists = []  # Track DimList positions for later processing

    def check_expanding(i: int) -> None:
        nonlocal expanding_object
        if expanding_object != -1:
            from . import DimensionBindError

            raise DimensionBindError(
                f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
                f"{expanding_object} and {i}"
            )
        expanding_object = i

    def is_dimpack(s: Any) -> bool:
        from . import Dim

        return (
            isinstance(s, (tuple, list))
            and len(s) > 0
            and all(Dim.check_exact(item) for item in s)
        )

    has_dimpacks_or_none = False
    for i, s in enumerate(input_list):
        if has_dims(s):
            can_call_original_getitem = False
            dims_indexed += 1
        elif s is ...:
            check_expanding(i)
        elif isinstance(s, DimList):
            can_call_original_getitem = False
            if not s.is_bound:
                check_expanding(i)
                unbound_dim_list = s
            else:
                dims_indexed += len(s._dims)
            dimlists.append(i)
        elif s is None:
            has_dimpacks_or_none = True
        elif is_dimpack(s):
            can_call_original_getitem = False
            has_dimpacks_or_none = True
            dims_indexed += 1
        else:
            dims_indexed += 1
    # Early return if we can use original getitem
    if can_call_original_getitem:
        return IndexingInfo(can_call_original=True)
    self_info = TensorInfo.create(self, False, True)
    total_dims = len(self_info.levels)  # Total dimensions (positional + named)
    if dims_indexed > total_dims:
        raise ValueError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
        )
    # Expand any unbound dimension list, or expand ... into individual : slices.
    expanding_dims = total_dims - dims_indexed
    if expanding_object != -1:
        if unbound_dim_list is not None:
            # Bind unbound dimension list to the expanding dimensions
            unbound_dim_list.bind_len(expanding_dims)
        else:
            # Expand ... into slice(None) objects
            no_slices = [slice(None)] * expanding_dims
            input_list = (
                input_list[:expanding_object]
                + no_slices
                + input_list[expanding_object + 1 :]
            )
    # Flatten out any dimensions stored in dimlist elements directly into the inputs
    # Process in reverse order to maintain indices
    for i in range(len(dimlists) - 1, -1, -1):
        idx = dimlists[i]
        # We added more elements to input because of ...
        # so we need to also adjust the index to get back to where the
        # dimlist existed
        if (
            unbound_dim_list is None
            and expanding_object != -1
            and idx > expanding_object
        ):
            idx += expanding_dims
        dl = input_list[idx]
        # PRIVATE here naughty
        input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
    return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
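

# Expansion arithmetic in getsetitem above (illustrative): indexing a 4-d
# tensor with (i0, ..., i1) gives dims_indexed == 2, so
# expanding_dims == 4 - 2 == 2 and the ... is replaced by two slice(None)
# entries, producing [i0, slice(None), slice(None), i1] before the flat pass.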


def getsetitem_flat(
    self_info: TensorInfo,
    input_list: list,
    keys: list[DimEntry],
    values: list,
    has_dimpacks_or_none: bool,
) -> IndexingInfo:
    from . import Dim

    # Track dimension usage
    seen_dims: list[Any] = []
    seen_dims_nuses: list[int] = []

    def add_dim(dim: Any) -> None:
        # Use safe indexing to avoid triggering __torch_function__ on Dim objects
        idx = _safe_index(seen_dims, dim)
        if idx is not None:
            seen_dims_nuses[idx] += 1
        else:
            seen_dims.append(dim)
            seen_dims_nuses.append(1)

    flat_inputs = []
    tensor_inputs: list[Any] = []
    device_holding_tensor = None

    def append_flat_handle(handle: Any) -> None:
        flat_inputs.append(handle)
        tensor_inputs.append(None)

    def append_tensor_input(ti: TensorInfo) -> None:
        flat_inputs.append(None)
        tensor_inputs.append(ti)
        nonlocal device_holding_tensor
        if ti.has_device and device_holding_tensor is None:
            device_holding_tensor = ti.tensor

    nsz = []
    nsd = []
    if self_info.tensor is None:
        raise RuntimeError("Cannot get size/stride on None tensor")
    sz = self_info.tensor.size()
    sd = self_info.tensor.stride()

    def append_size(i: int) -> None:
        if has_dimpacks_or_none:
            nsz.append(sz[i])
            nsd.append(sd[i])

    input_it = input_list[:]

    def parse_nones() -> None:
        nonlocal input_it
        while input_it and input_it[0] is None:
            append_flat_handle(slice(None))
            nsz.append(1)
            nsd.append(0)
            input_it = input_it[1:]

    def append_item(i: int, arg: Any) -> None:
        if Dim.check_exact(arg):
            d = arg
            if d._size == -1:
                d.size = sz[i]
            add_dim(d)
            append_size(i)
            append_flat_handle(arg)
            return
        info = TensorInfo.create(arg, False, False)
        if info:
            append_size(i)
            append_tensor_input(info)
            for level in info.levels:
                if not level.is_positional():
                    add_dim(level.dim())
            return
        if has_dimpacks_or_none:
            if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
                # dim pack
                dim_pack = list(arg)
                for d in dim_pack:
                    add_dim(d)
                    append_flat_handle(d)
                _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
                return
        append_size(i)
        append_flat_handle(arg)

    # Match indexing expressions with tensor dimensions
    for i, level in enumerate(self_info.levels):
        # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
        idx = _safe_index(keys, level)
        if idx is not None:
            append_item(i, values[idx])
        else:
            if level.is_positional():
                parse_nones()
                if not input_it:
                    append_flat_handle(slice(None))
                    append_size(i)
                else:
                    arg = input_it[0]
                    input_it = input_it[1:]
                    append_item(i, arg)
            else:
                add_dim(level.dim())
                append_flat_handle(level.dim())
                append_size(i)
    parse_nones()
    # Restride tensor if needed
    if has_dimpacks_or_none and nsz:
        if self_info.tensor is None:
            raise RuntimeError("Cannot restride None tensor")
        self_tensor = self_info.tensor.as_strided(
            nsz, nsd, self_info.tensor.storage_offset()
        )
    else:
        self_tensor = self_info.tensor
    # Determine result shape and indexing requirements
    result_levels: list[Any] = []
    index_levels = []
    tensor_insert_point = -1
    requires_getindex = False

    def mark_tensor_index() -> None:
        nonlocal tensor_insert_point
        if tensor_insert_point == -1:
            tensor_insert_point = len(result_levels)
        elif tensor_insert_point != len(result_levels):
            tensor_insert_point = 0

    for i, inp in enumerate(flat_inputs):
        if tensor_inputs[i] is not None:
            requires_getindex = True
            mark_tensor_index()
            for level in tensor_inputs[i].levels:
                if level not in index_levels:
                    index_levels.append(level)
        elif Dim.check_exact(inp):
            d = inp
            # Use safe indexing to avoid triggering __torch_function__
            dim_idx = _safe_index(seen_dims, d)
            assert dim_idx is not None, f"Dim {d} not found in seen_dims"
            if seen_dims_nuses[dim_idx] == 1:
                flat_inputs[i] = slice(None)
                result_levels.append(DimEntry(d))
            else:
                requires_getindex = True
                flat_inputs[i] = None
                tensor_inputs[i] = TensorInfo(
                    d._get_range(), [DimEntry(d)], False, None
                )
                if DimEntry(d) not in index_levels:
                    index_levels.append(DimEntry(d))
                mark_tensor_index()
        else:
            if inp != slice(None):
                requires_getindex = True
            if not isinstance(inp, int):
                result_levels.append(DimEntry(-1))
    # Insert indexing dimensions at first tensor use point
    if tensor_insert_point != -1:
        for level in reversed(index_levels):
            result_levels.insert(tensor_insert_point, level)
    # Match tensors to indexing shape
    if requires_getindex:
        for i in range(len(flat_inputs)):
            if tensor_inputs[i] is not None:
                t = tensor_inputs[i].tensor
                assert t is not None, "TensorInfo should have valid tensor data"
                if (
                    not tensor_inputs[i].has_device
                    and device_holding_tensor is not None
                ):
                    t = t.to(device_holding_tensor.device)
                flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
    # Number positional dimensions correctly
    seen_positionals = 0
    for i in reversed(range(len(result_levels))):
        if result_levels[i].is_positional():
            seen_positionals += 1
            result_levels[i] = DimEntry(-seen_positionals)
    return IndexingInfo(
        can_call_original=False,
        advanced_indexing=requires_getindex,
        self_tensor=self_tensor,
        flat_inputs=flat_inputs,
        result_levels=result_levels,
        has_device=self_info.has_device,
    )
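

# End-to-end sketch (hypothetical shapes): for t = torch.randn(3, 4) and a
# first-class dim i (size 3), t[i] flows getitem -> getsetitem ->
# getsetitem_flat. The flat pass consumes i for dim 0 and leaves dim 1
# positional; since i is used once it becomes slice(None), no advanced
# indexing is required, and invoke_getitem wraps the original tensor with
# levels [DimEntry(i), DimEntry(-1)] via Tensor.from_positional.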