Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__`, check whether a special FCD implementation needs to be dispatched to first, and if there is nothing we fall back to the standard level strategy.

Because there is no longer a C-binding equivalent of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs, as the Python API is sometimes different from the C API. I've attempted to disambiguate these, but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful, as Dim must abide by Tensor equality semantics but uses pointer equality in C (DimEntry doesn't have this problem). Another subtle C/Python difference is that we no longer get implicit conversions from Dim to DimEntry; this also caused some bugs.

Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at, so I haven't done that in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything; one noticeably missing "small" thing is patched_dim_method. I am still in the process of carefully doing a side-by-side review of the ports; "simplifications" from claude code were also a major source of bugs.

There are two major feature gaps in the implementation:

- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy; we just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of it being eligible for `__torch_function__`. I think I might need to hard-code a special case for this or something?

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
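As a rough illustration of the dispatch strategy described above, here is a minimal sketch. `FCD_SPECIAL_IMPLS` is a hypothetical registry name (not the actual helper), and the fallback is stubbed with the default torch behavior rather than the real level strategy:

```python
from typing import Any, Callable

import torch

# Hypothetical registry of special FCD handlers keyed by torch function,
# e.g. {torch.Tensor.__getitem__: getitem} for the code in this file.
FCD_SPECIAL_IMPLS: dict[Any, Callable] = {}


class Tensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # 1. See if a special FCD implementation should be dispatched to first.
        special = FCD_SPECIAL_IMPLS.get(func)
        if special is not None:
            return special(cls, func, types, args, kwargs)
        # 2. Nothing special: fall back to the standard level strategy
        #    (stubbed here with the default torch behavior).
        return super().__torch_function__(func, types, args, kwargs)
```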
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union

import torch

from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo


if TYPE_CHECKING:
    from . import Dim

def _safe_index(lst: list, item: Any) -> Optional[int]:
    """
    Helper function to find the index of an item in a list.

    For DimEntry objects, uses __eq__ comparison, which properly handles
    both positional and Dim entries.

    Returns the index if found, None if not found.
    """
    for i, list_item in enumerate(lst):
        # Use == for DimEntry objects as they have a proper __eq__ implementation
        if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
            if list_item == item:
                return i
        elif list_item is item:
            return i
    return None
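
# Why not list.index()? Dim follows Tensor equality semantics, so `==` on a
# Dim routes through __torch_function__ and does not return a plain bool,
# whereas DimEntry defines an ordinary __eq__. Illustrative use (values are
# hypothetical):
#
#   d = dims(1)                              # a single first-class dimension
#   _safe_index([d], d)                      # -> 0, matched by identity (`is`)
#   _safe_index([DimEntry(d)], DimEntry(d))  # -> 0, matched by __eq__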

@dataclass
class IndexingInfo:
    can_call_original: bool = False
    advanced_indexing: bool = False
    self_tensor: Optional[torch.Tensor] = None
    flat_inputs: list[Any] = field(default_factory=list)
    result_levels: list[DimEntry] = field(default_factory=list)
    has_device: bool = False

def has_dims(obj: Any) -> bool:
    """
    Check if an object has first-class dimensions.

    This function checks if the object is either a Dim or a functorch Tensor
    that has first-class dimensions, using the proper check_exact methods.
    """
    from . import Dim, Tensor

    return Dim.check_exact(obj) or Tensor.check_exact(obj)

def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
    """
    Bind dimensions to size and calculate proper strides for dim packs.
    """
    from . import DimensionBindError

    rhs_prod = 1
    for i, dim in enumerate(dims):
        if not dim.is_bound:
            # Check for multiple unbound dimensions
            for j in range(i + 1, len(dims)):
                if not dims[j].is_bound:
                    raise DimensionBindError(
                        f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
                    )
                rhs_prod *= dims[j].size

            # Calculate the size for this unbound dimension
            if sz % rhs_prod != 0:
                tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
                raise DimensionBindError(
                    f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
                )

            inferred_size = sz // rhs_prod
            dim.size = inferred_size
            rhs_prod = sz
            break
        else:
            rhs_prod *= dim.size

    # Final validation that dimensions match
    if rhs_prod != sz:
        tup = tuple(dims)
        raise DimensionBindError(
            f"Dimension sizes do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
        )

    # Calculate new sizes and strides for each dimension in the pack
    # First calculate all strides by iterating in reverse
    new_strides = [0] * len(dims)
    current_stride = sd
    for i in reversed(range(len(dims))):
        new_strides[i] = current_stride
        current_stride *= dims[i].size

    # Then append sizes and strides in forward order
    for i in range(len(dims)):
        nsz.append(dims[i].size)
        nsd.append(new_strides[i])
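
# Worked example (illustrative): binding the pack (x, y) to a dimension of
# size sz=12 with stride sd=2, where y is already bound to size 3, infers
# x.size = 12 // 3 = 4. Strides are assigned right-to-left starting from sd,
# so this appends nsz += [4, 3] and nsd += [6, 2] (x strides over y's extent).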

def slice_to_tuple(flat_inputs: list) -> tuple:
    return tuple(flat_inputs)

def extractIndices(index: Any, indices: list) -> bool:
    if isinstance(index, tuple):  # mpy::tuple_view::check
        indices.extend(index)
        return True
    elif isinstance(index, torch.Tensor):  # THPVariable_Check
        indices.append(index)
        return False
    elif not hasattr(index, "__iter__") or isinstance(
        index, (str, bytes)
    ):  # !mpy::is_sequence
        indices.append(index)
        return False

    # Handle sequence case (list)
    if isinstance(index, list):
        if len(index) >= 32:
            indices.extend(index)
            return True

        # Check each item in the sequence
        for item in index:
            if (
                isinstance(item, (torch.Tensor, slice))
                or hasattr(item, "__iter__")
                or item is ...
                or item is None
                or has_dims(item)
            ):
                indices.extend(index)
                return True

        # If we got here, treat as single index
        indices.append(index)
        return False

    # Default case
    indices.append(index)
    return False
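
# Normalization examples (illustrative; i, j are Dims, t is a Tensor):
#   extractIndices((i, j), out)            # tuple -> unpacked, returns True
#   extractIndices(t, out)                 # tensor -> single index, False
#   extractIndices([0, 2, 4], out)         # plain int list -> single index, False
#   extractIndices([i, slice(None)], out)  # list with dims/slices -> unpacked, True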

def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
    self = args[0]
    index = args[1]

    iinfo = getsetitem(self, index, has_dims(self))
    if iinfo.can_call_original:
        # Call original tensor __getitem__ directly, bypassing __torch_function__
        return torch.Tensor.__getitem__(self, index)

    return invoke_getitem(iinfo)

def setitem(self: Any, index: Any, rhs: Any) -> None:
    """Set values in tensor using first-class dimensions."""
    from . import DimensionBindError, TensorInfo

    iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))

    if iinfo.can_call_original:
        # Call original tensor __setitem__ directly, bypassing __torch_function__
        torch._C.TensorBase.__setitem__(self, index, rhs)
        return

    # Handle RHS tensor with dimensions
    rhs_info = TensorInfo.create(rhs, False, False)

    if rhs_info:
        # Check that rhs dimensions are compatible with result dimensions
        for l in rhs_info.levels:
            if not l.is_positional():
                # Find this dimension in result levels
                found = False
                for result_level in iinfo.result_levels:
                    if (
                        not result_level.is_positional()
                        and result_level.dim() is l.dim()
                    ):
                        found = True
                        break

                if not found:
                    # Create tuple representation of result levels for error message
                    result_dims: list[Union[int, Dim]] = []
                    for rl in iinfo.result_levels:
                        if rl.is_positional():
                            result_dims.append(rl.position())
                        else:
                            result_dims.append(rl.dim())

                    raise DimensionBindError(
                        f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimensions on the left "
                        f"({tuple(result_dims)!r})"
                    )

        # Match RHS tensor to result levels
        assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
        matched_rhs = _match_levels(
            rhs_info.tensor, rhs_info.levels, iinfo.result_levels
        )
    else:
        matched_rhs = rhs

    # For advanced indexing with dimensions, we need special handling
    if iinfo.advanced_indexing:
        # Use advanced indexing - the flat_inputs already contain matched tensors
        tup = slice_to_tuple(iinfo.flat_inputs)
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot setitem on None tensor")
        torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
    else:
        # Simple copy operation
        if iinfo.self_tensor is None:
            raise RuntimeError("Cannot copy to None tensor")
        iinfo.self_tensor.copy_(matched_rhs)

def invoke_getitem(iinfo: IndexingInfo) -> Any:
    if iinfo.advanced_indexing:
        self_tensor = iinfo.self_tensor
        tup = slice_to_tuple(iinfo.flat_inputs)
        if self_tensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        rtensor = self_tensor[tup]
    else:
        rtensor = iinfo.self_tensor  # type: ignore[assignment]
        if rtensor is None:
            raise RuntimeError("Cannot getitem on None tensor")
        # rtensor is now guaranteed to be not None

    # Create a Tensor with the proper dimensions using the class method
    from . import Tensor

    return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
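
# End-to-end sketch (illustrative): for A[i, j] with Dims i, j each used once,
# getsetitem() binds the dims to A's sizes, every flat input reduces to
# slice(None), so advanced indexing is skipped and this returns
# Tensor.from_positional(A, [DimEntry(i), DimEntry(j)], ...).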

def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
    from . import DimList  # Import DimList for type checking

    can_call_original_getitem = not tensors_have_dims

    input_list = []
    if has_dims(index):
        input_list.append(index)
    else:
        is_sequence = extractIndices(index, input_list)
        # nothing about first-class dims here, fall back to getitem
        if can_call_original_getitem and not is_sequence:
            return IndexingInfo(can_call_original=True)

    # Calculate how many dimensions have been indexed in order to compute the
    # size of ... or expand a potentially unbound dimension list.
    dims_indexed = 0
    expanding_object = -1
    unbound_dim_list = None
    dimlists = []  # Track DimList positions for later processing

    def check_expanding(i: int) -> None:
        nonlocal expanding_object
        if expanding_object != -1:
            from . import DimensionBindError

            raise DimensionBindError(
                f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
                f"{expanding_object} and {i}"
            )
        expanding_object = i

    def is_dimpack(s: Any) -> bool:
        from . import Dim

        return (
            isinstance(s, (tuple, list))
            and len(s) > 0
            and all(Dim.check_exact(item) for item in s)
        )

    has_dimpacks_or_none = False
    for i, s in enumerate(input_list):
        if has_dims(s):
            can_call_original_getitem = False
            dims_indexed += 1
        elif s is ...:
            check_expanding(i)
        elif isinstance(s, DimList):
            can_call_original_getitem = False
            if not s.is_bound:
                check_expanding(i)
                unbound_dim_list = s
            else:
                dims_indexed += len(s._dims)
            dimlists.append(i)
        elif s is None:
            has_dimpacks_or_none = True
        elif is_dimpack(s):
            can_call_original_getitem = False
            has_dimpacks_or_none = True
            dims_indexed += 1
        else:
            dims_indexed += 1

    # Early return if we can use original getitem
    if can_call_original_getitem:
        return IndexingInfo(can_call_original=True)

    self_info = TensorInfo.create(self, False, True)
    total_dims = len(self_info.levels)  # Total dimensions (positional + named)
    if dims_indexed > total_dims:
        raise ValueError(
            f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
        )

    # Expand any unbound dimension list, or expand ... into individual : slices.
    expanding_dims = total_dims - dims_indexed
    if expanding_object != -1:
        if unbound_dim_list is not None:
            # Bind unbound dimension list to the expanding dimensions
            unbound_dim_list.bind_len(expanding_dims)
        else:
            # Expand ... into slice(None) objects
            no_slices = [slice(None)] * expanding_dims
            input_list = (
                input_list[:expanding_object]
                + no_slices
                + input_list[expanding_object + 1 :]
            )

    # Flatten out any dimensions stored in dimlist elements directly into the inputs.
    # Process in reverse order to maintain indices.
    for i in range(len(dimlists) - 1, -1, -1):
        idx = dimlists[i]

        # We added more elements to input because of ...
        # so we need to also adjust the index to get back to where the
        # dimlist existed
        if (
            unbound_dim_list is None
            and expanding_object != -1
            and idx > expanding_object
        ):
            idx += expanding_dims

        dl = input_list[idx]

        # PRIVATE here naughty
        input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]

    return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
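
# Example of the ... expansion above (illustrative): indexing a 4-d tensor as
# t[x, ..., y] gives dims_indexed == 2, so expanding_dims == 2 and the ... at
# offset 1 is replaced by two slice(None) entries before getsetitem_flat runs.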

def getsetitem_flat(
    self_info: TensorInfo,
    input_list: list,
    keys: list[DimEntry],
    values: list,
    has_dimpacks_or_none: bool,
) -> IndexingInfo:
    from . import Dim

    # Track dimension usage
    seen_dims: list[Any] = []
    seen_dims_nuses: list[int] = []

    def add_dim(dim: Any) -> None:
        # Use safe indexing to avoid triggering __torch_function__ on Dim objects
        idx = _safe_index(seen_dims, dim)
        if idx is not None:
            seen_dims_nuses[idx] += 1
        else:
            seen_dims.append(dim)
            seen_dims_nuses.append(1)

    flat_inputs = []
    tensor_inputs: list[Any] = []
    device_holding_tensor = None

    def append_flat_handle(handle: Any) -> None:
        flat_inputs.append(handle)
        tensor_inputs.append(None)

    def append_tensor_input(ti: TensorInfo) -> None:
        flat_inputs.append(None)
        tensor_inputs.append(ti)
        nonlocal device_holding_tensor
        if ti.has_device and device_holding_tensor is None:
            device_holding_tensor = ti.tensor

    nsz = []
    nsd = []
    if self_info.tensor is None:
        raise RuntimeError("Cannot get size/stride on None tensor")
    sz = self_info.tensor.size()
    sd = self_info.tensor.stride()

    def append_size(i: int) -> None:
        if has_dimpacks_or_none:
            nsz.append(sz[i])
            nsd.append(sd[i])

    input_it = input_list[:]

    def parse_nones() -> None:
        nonlocal input_it
        while input_it and input_it[0] is None:
            append_flat_handle(slice(None))
            nsz.append(1)
            nsd.append(0)
            input_it = input_it[1:]

    def append_item(i: int, arg: Any) -> None:
        if Dim.check_exact(arg):
            d = arg
            if d._size == -1:
                d.size = sz[i]
            add_dim(d)
            append_size(i)
            append_flat_handle(arg)
            return

        info = TensorInfo.create(arg, False, False)
        if info:
            append_size(i)
            append_tensor_input(info)
            for level in info.levels:
                if not level.is_positional():
                    add_dim(level.dim())
            return

        if has_dimpacks_or_none:
            if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
                # dim pack
                dim_pack = list(arg)
                for d in dim_pack:
                    add_dim(d)
                    append_flat_handle(d)
                _bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
                return

        append_size(i)
        append_flat_handle(arg)

    # Match indexing expressions with tensor dimensions
    for i, level in enumerate(self_info.levels):
        # Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
        idx = _safe_index(keys, level)
        if idx is not None:
            append_item(i, values[idx])
        else:
            if level.is_positional():
                parse_nones()
                if not input_it:
                    append_flat_handle(slice(None))
                    append_size(i)
                else:
                    arg = input_it[0]
                    input_it = input_it[1:]
                    append_item(i, arg)
            else:
                add_dim(level.dim())
                append_flat_handle(level.dim())
                append_size(i)

    parse_nones()

    # Restride tensor if needed
    if has_dimpacks_or_none and nsz:
        if self_info.tensor is None:
            raise RuntimeError("Cannot restride None tensor")
        self_tensor = self_info.tensor.as_strided(
            nsz, nsd, self_info.tensor.storage_offset()
        )
    else:
        self_tensor = self_info.tensor

    # Determine result shape and indexing requirements
    result_levels: list[Any] = []
    index_levels = []
    tensor_insert_point = -1
    requires_getindex = False

    def mark_tensor_index() -> None:
        nonlocal tensor_insert_point
        if tensor_insert_point == -1:
            tensor_insert_point = len(result_levels)
        elif tensor_insert_point != len(result_levels):
            tensor_insert_point = 0

    for i, inp in enumerate(flat_inputs):
        if tensor_inputs[i] is not None:
            requires_getindex = True
            mark_tensor_index()
            for level in tensor_inputs[i].levels:
                if level not in index_levels:
                    index_levels.append(level)
        elif Dim.check_exact(inp):
            d = inp
            # Use safe indexing to avoid triggering __torch_function__
            dim_idx = _safe_index(seen_dims, d)
            assert dim_idx is not None, f"Dim {d} not found in seen_dims"
            if seen_dims_nuses[dim_idx] == 1:
                flat_inputs[i] = slice(None)
                result_levels.append(DimEntry(d))
            else:
                requires_getindex = True
                flat_inputs[i] = None
                tensor_inputs[i] = TensorInfo(
                    d._get_range(), [DimEntry(d)], False, None
                )
                if DimEntry(d) not in index_levels:
                    index_levels.append(DimEntry(d))
                mark_tensor_index()
        else:
            if inp != slice(None):
                requires_getindex = True
            if not isinstance(inp, int):
                result_levels.append(DimEntry(-1))

    # Insert indexing dimensions at first tensor use point
    if tensor_insert_point != -1:
        for level in reversed(index_levels):
            result_levels.insert(tensor_insert_point, level)

    # Match tensors to indexing shape
    if requires_getindex:
        for i in range(len(flat_inputs)):
            if tensor_inputs[i] is not None:
                t = tensor_inputs[i].tensor
                assert t is not None, "TensorInfo should have valid tensor data"
                if (
                    not tensor_inputs[i].has_device
                    and device_holding_tensor is not None
                ):
                    t = t.to(device_holding_tensor.device)
                flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)

    # Number positional dimensions correctly
    seen_positionals = 0
    for i in reversed(range(len(result_levels))):
        if result_levels[i].is_positional():
            seen_positionals += 1
            result_levels[i] = DimEntry(-seen_positionals)

    return IndexingInfo(
        can_call_original=False,
        advanced_indexing=requires_getindex,
        self_tensor=self_tensor,
        flat_inputs=flat_inputs,
        result_levels=result_levels,
        has_device=self_info.has_device,
    )
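
# Overall usage sketch (illustrative; assumes the functorch.dim frontend):
#
#   import torch
#   from functorch.dim import dims
#
#   i, j = dims(2)
#   A = torch.randn(3, 4)
#   Aij = A[i, j]         # routed through getitem() above; binds i to 3, j to 4
#   A2 = Aij.order(i, j)  # back to an ordinary positional tensor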