torchdim Python port (#160236)

The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new dispatch algorithm is that we first land in `__torch_function__`, check whether a special FCD implementation needs to be dispatched to first, and if there is nothing we fall back to the standard level strategy.
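
As a rough illustration of that flow, here is a minimal sketch (assuming the FCD tensor type subclasses `torch.Tensor`; the `_FCD_SPECIAL` registry name is made up for illustration and is not the actual structure in this PR):

```python
# Sketch only: shows the "check for a special FCD implementation first,
# otherwise fall back" shape of the new dispatch, not the real code.
import torch

_FCD_SPECIAL: dict = {}  # torch function -> dedicated FCD implementation


class _SketchTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        special = _FCD_SPECIAL.get(func)
        if special is not None:
            # e.g. __getitem__/__setitem__ and dim-aware reductions get
            # dedicated first-class-dim handling
            return special(cls, func, types, args, kwargs)
        # Nothing special registered: this sketch just defers to the default
        # Tensor behavior; the real code falls back to the "standard level
        # strategy" (enable vmap layers for all bound dims, then let
        # functorch's batching rules run the op).
        return super().__torch_function__(func, types, args, kwargs)
```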

Because there is no longer a C binding equivalent of the classes, we've condensed _C.Dim and Dim together, and similarly for Tensor. This resulted in some bugs, as the Python API is sometimes different from the C API. I've attempted to disambiguate these, but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful, as Dim must abide by Tensor equality semantics but uses pointer equality in C++ (DimEntry doesn't have this problem). Another subtle C++/Python difference is that we no longer get implicit conversions from Dim to DimEntry; this also caused some bugs.
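
To make the equality hazard concrete, here is a simplified version of the rule `DimEntry.__eq__` (in `_dim_entry.py` below) has to follow:

```python
# Positional entries are plain negative ints, so value equality is fine.
# Dim objects follow Tensor equality semantics (== goes through
# __torch_function__), so DimEntry compares them by identity instead.
def _entries_equal(a, b) -> bool:
    if isinstance(a, int) and isinstance(b, int):
        return a == b      # both positional
    if isinstance(a, int) or isinstance(b, int):
        return False       # one positional, one Dim: never equal
    return a is b          # both Dims: pointer identity, never ==
```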

Much of the mechanical porting work was done by Claude Code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point Claude at, so I haven't done that here. From a reviewing perspective, I need to re-review that I didn't forget to port anything; one noticeably missing "small" thing is patched_dim_method. I am still in the process of carefully doing a side-by-side review of the ports; "simplifications" from Claude Code were also a major source of bugs.

There are two major feature gaps in the implementation:

- DelayedTensor and dot handling are not implemented yet. This should be reasonably easy; we just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls.
- Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of letting it reach `__torch_function__`. I think I might need to hard-code a special case for this or something? A concrete repro of the failing pattern is sketched below.
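
For concreteness, the failing pattern looks roughly like this (sizes are arbitrary and chosen only for illustration):

```python
# Illustrative repro of the dim-splitting gap described above.
import torch
from functorch.dim import dims

x, y = dims(2)
x.size = 2             # y's size (3) would be inferred from 6 // 2
t = torch.randn(6)

# Intended: split the single size-6 dimension into x and y. Today,
# Tensor.__getitem__ treats the bare list as an advanced-indexing index and
# converts [x, y] with torch.tensor(...) before __torch_function__ ever sees
# the Dim objects, so this does not work yet.
t_split = t[[x, y]]
```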

Signed-off-by: Edward Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236
Approved by: https://github.com/zdevito, https://github.com/albanD
Edward Yang
2025-09-20 12:21:15 -04:00
committed by PyTorch MergeBot
parent 2887f3fde4
commit 97eb7a281d
22 changed files with 3725 additions and 123 deletions


@ -746,12 +746,14 @@ These compilers and language have syntax and semantics that resemble the loop-le
Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensor containing dimensions can compute through code that is oblivious to the dimension, such as batching examples. There is no need to separate code into 'compiled' vs 'eager'.
In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.
In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries. Note, however, that first-class dimensions are not natively compiled, so if you write code that performs many outer products with the expectation of it being fused, you will generally not get good performance or memory use (except for matrix-multiply-like patterns specifically).
Performance Expectations
========================
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve performance of this style of code.
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel.
Originally, there was a C++ implementation of dimensions that added a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that used them. However, this implementation had some manual memory management bugs and was not kept up to date with CPython updates. The latest Python implementation is two orders of magnitude slower due to CPU overhead; for overhead-sensitive applications you should compile the code to eliminate this overhead.
## License

File diff suppressed because it is too large

functorch/dim/_dim_entry.py (new file, 127 lines)

@ -0,0 +1,127 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Sequence
from . import Dim
import torch # noqa: TC002
# NB: The old code represented the dimension this was from as a negative number,
# so we follow this convention even though it shouldn't be necessary now
class DimEntry:
# The position of the dimension counted from the rhs (a negative int), or a first-class Dim
data: Union[Dim, int]
def __init__(self, data: Union[Dim, int, None] = None) -> None:
from . import Dim
if type(data) is int:
assert data < 0
elif data is None:
data = 0
else:
assert isinstance(data, Dim)
self.data = data
def __eq__(self, other: object) -> bool:
if not isinstance(other, DimEntry):
return False
# Use 'is' for Dim objects to avoid triggering __torch_function__
# Use '==' only for positional (int) comparisons
if self.is_positional() and other.is_positional():
# Both are positional (ints)
return self.data == other.data
elif not self.is_positional() and not other.is_positional():
# Both are Dim objects - use 'is' to avoid __eq__
return self.data is other.data
else:
# One is positional, one is Dim - they can't be equal
return False
def is_positional(self) -> bool:
return type(self.data) is int and self.data < 0
def is_none(self) -> bool:
# Use isinstance to check for Dim objects, avoid triggering __torch_function__
from . import Dim
if isinstance(self.data, Dim):
# This is a Dim object, it can't be "none" (which is represented by 0)
return False
else:
# This is an int or other type
return self.data == 0
def position(self) -> int:
assert isinstance(self.data, int)
return self.data
def dim(self) -> Dim:
assert not isinstance(self.data, int)
return self.data
def __repr__(self) -> str:
return repr(self.data)
def ndim_of_levels(levels: Sequence[DimEntry]) -> int:
r = 0
for l in levels:
if l.is_positional():
r += 1
return r
def _match_levels(
tensor: torch.Tensor,
from_levels: list[DimEntry],
to_levels: list[DimEntry],
drop_levels: bool = False,
) -> torch.Tensor:
"""
Reshape a tensor to match target levels using as_strided.
Args:
tensor: Input tensor to reshape
from_levels: Current levels of the tensor
to_levels: Target levels to match
drop_levels: If True, missing dimensions are assumed to have stride 0
Returns:
Reshaped tensor
"""
if from_levels == to_levels:
return tensor
sizes = tensor.size()
strides = tensor.stride()
if not drop_levels:
assert len(from_levels) <= len(to_levels), (
"Cannot expand dimensions without drop_levels"
)
new_sizes = []
new_strides = []
for level in to_levels:
# Find index of this level in from_levels
try:
idx = from_levels.index(level)
except ValueError:
# Level not found in from_levels
if level.is_positional():
new_sizes.append(1)
else:
new_sizes.append(level.dim().size)
new_strides.append(0)
else:
new_sizes.append(sizes[idx])
new_strides.append(strides[idx])
return tensor.as_strided(new_sizes, new_strides, tensor.storage_offset())


@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING
import torch
from ._dim_entry import DimEntry
if TYPE_CHECKING:
from . import Dim, Tensor
class EnableAllLayers:
"""
RAII-style context manager for enabling functorch vmap layers.
It manages the creation and cleanup of functorch dynamic layers.
This is probably one of the more algorithmically important parts of first
class dims. Intuitively, FCD can be thought of as another way of using
vmap, where you don't actually have to vmap at the top level, instead the
vmaps are implicitly determined by inspecting the bound dimensions on the
FCD tensors involved in a compute (this is similar to our concept of
non-lexical modes that we spent a long time talking about years ago). But
under the hood you still need to actually enable the vmap mode. So once
FCD has determined all of the dims we are batching over, it needs to
enable all those layers so functorch can actually apply the batching
rules. Therefore enable all layers!
"""
levels_start: int
levels_to_dim: list[Dim]
def __init__(self, levels: list[DimEntry]):
"""
Initialize and push dynamic layers for all first-class dimensions.
Args:
levels: List of dimension entries to create layers for
"""
from . import Dim
self.levels_start = 0
self.levels_to_dim = []
for l in levels:
if not l.is_positional():
d = l.dim()
assert isinstance(d, Dim)
self.levels_to_dim.append(d)
# Sort by level for stable ordering
self.levels_to_dim.sort(key=lambda d: d._level)
def __enter__(self) -> EnableAllLayers: # noqa: PYI034
# Create functorch dynamic layers
for i, dim in enumerate(self.levels_to_dim):
batch_size = dim.size
level = torch._C._functorch._vmap_increment_nesting(batch_size, "different")
if i == 0:
self.levels_start = level
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Clean up dynamic layers in reverse order."""
to_remove = self.levels_start + len(self.levels_to_dim) - 1
for i in range(len(self.levels_to_dim)):
popped = torch._C._functorch._vmap_decrement_nesting()
assert popped == to_remove - i, (
f"Expected layer {to_remove - i}, got {popped}"
)
def from_batched(self, batchedtensor: torch.Tensor, has_device: bool) -> Tensor:
"""
Create a Tensor from a batched tensor by unwrapping functorch layers.
Args:
batchedtensor: Batched tensor from functorch operation
has_device: Whether tensor has device info
Returns:
Tensor with appropriate levels
"""
# Create positional levels for base dimensions
levels: list[DimEntry] = []
for i in range(-batchedtensor.dim(), 0):
levels.append(DimEntry(i))
tensor = batchedtensor
while torch._C._functorch.is_batchedtensor(tensor):
level = torch._C._functorch.maybe_get_level(tensor)
assert level is not None
assert level >= self.levels_start and level < self.levels_start + len(
self.levels_to_dim
)
dim = DimEntry(self.levels_to_dim[level - self.levels_start])
bdim = torch._C._functorch.maybe_get_bdim(tensor)
assert bdim is not None
levels.insert(bdim, dim)
tensor = torch._C._functorch.get_unwrapped(tensor)
from . import Tensor
result = Tensor()
result._tensor = tensor
result._batchtensor = batchedtensor
result._has_device = has_device
result._levels = levels
return result
def inplace_update_layers(
self, batchtensor: torch.Tensor, levels: list[DimEntry]
) -> None:
"""
Update the levels of a batched tensor in place.
This requires the _maybe_unsafe_set_level binding that we'll add to functorch.
Args:
batchtensor: Batched tensor to update
levels: New levels to set
"""
# Check if tensor is batched
if not torch._C._functorch.is_batchedtensor(batchtensor):
return
impl = batchtensor
for i in reversed(range(len(self.levels_to_dim))):
if impl is None:
break
if any(l == DimEntry(self.levels_to_dim[i]) for l in levels):
# This is very interesting! The level on batch tensor is
# meaningless! We set it RIGHT before we go into vmap
torch._C._functorch._maybe_unsafe_set_level(impl, self.levels_start + i)
impl = torch._C._functorch.get_unwrapped(impl)


@ -0,0 +1,561 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
from ._dim_entry import _match_levels, DimEntry
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
from . import Dim
def _safe_index(lst: list, item: Any) -> Optional[int]:
"""
Helper function to find index of item in list.
For DimEntry objects, uses __eq__ comparison which properly handles
both positional and Dim entries.
Returns the index if found, None if not found.
"""
for i, list_item in enumerate(lst):
# Use == for DimEntry objects as they have proper __eq__ implementation
if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
if list_item == item:
return i
elif list_item is item:
return i
return None
@dataclass
class IndexingInfo:
can_call_original: bool = False
advanced_indexing: bool = False
self_tensor: Optional[torch.Tensor] = None
flat_inputs: list[Any] = field(default_factory=list)
result_levels: list[DimEntry] = field(default_factory=list)
has_device: bool = False
def has_dims(obj: Any) -> bool:
"""
Check if an object has first-class dimensions.
This function checks if the object is either a Dim or a functorch Tensor
that has first-class dimensions, using the proper check_exact methods.
"""
from . import Dim, Tensor
return Dim.check_exact(obj) or Tensor.check_exact(obj)
def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
"""
Bind dimensions to size and calculate proper strides for dim packs.
"""
from . import DimensionBindError
rhs_prod = 1
for i, dim in enumerate(dims):
if not dim.is_bound:
# Check for multiple unbound dimensions
for j in range(i + 1, len(dims)):
if not dims[j].is_bound:
raise DimensionBindError(
f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
)
rhs_prod *= dims[j].size
# Calculate the size for this unbound dimension
if sz % rhs_prod != 0:
tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
raise DimensionBindError(
f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
)
inferred_size = sz // rhs_prod
dim.size = inferred_size
rhs_prod = sz
break
else:
rhs_prod *= dim.size
# Final validation that dimensions match
if rhs_prod != sz:
tup = tuple(dims)
raise DimensionBindError(
f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
)
# Calculate new sizes and strides for each dimension in the pack
# First calculate all strides by iterating in reverse
new_strides = [0] * len(dims)
current_stride = sd
for i in reversed(range(len(dims))):
new_strides[i] = current_stride
current_stride *= dims[i].size
# Then append sizes and strides in forward order
for i in range(len(dims)):
nsz.append(dims[i].size)
nsd.append(new_strides[i])
def slice_to_tuple(flat_inputs: list) -> tuple:
return tuple(flat_inputs)
def extractIndices(index: Any, indices: list) -> bool:
if isinstance(index, tuple): # mpy::tuple_view::check
indices.extend(index)
return True
elif isinstance(index, torch.Tensor): # THPVariable_Check
indices.append(index)
return False
elif not hasattr(index, "__iter__") or isinstance(
index, (str, bytes)
): # !mpy::is_sequence
indices.append(index)
return False
# Handle sequence case (list)
if isinstance(index, list):
if len(index) >= 32:
indices.extend(index)
return True
# Check each item in the sequence
for item in index:
if (
isinstance(item, (torch.Tensor, slice))
or hasattr(item, "__iter__")
or item is ...
or item is None
or has_dims(item)
):
indices.extend(index)
return True
# If we got here, treat as single index
indices.append(index)
return False
# Default case
indices.append(index)
return False
def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
self = args[0]
index = args[1]
iinfo = getsetitem(self, index, has_dims(self))
if iinfo.can_call_original:
# Call original tensor __getitem__ directly, bypassing __torch_function__
return torch.Tensor.__getitem__(self, index)
return invoke_getitem(iinfo)
def setitem(self: Any, index: Any, rhs: Any) -> None:
"""Set values in tensor using first-class dimensions."""
from . import DimensionBindError, TensorInfo
iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
if iinfo.can_call_original:
# Call original tensor __setitem__ directly, bypassing __torch_function__
torch._C.TensorBase.__setitem__(self, index, rhs)
return
# Handle RHS tensor with dimensions
rhs_info = TensorInfo.create(rhs, False, False)
if rhs_info:
# Check that rhs dimensions are compatible with result dimensions
for l in rhs_info.levels:
if not l.is_positional():
# Find this dimension in result levels
found = False
for result_level in iinfo.result_levels:
if (
not result_level.is_positional()
and result_level.dim() is l.dim()
):
found = True
break
if not found:
# Create tuple representation of result levels for error message
result_dims: list[Union[int, Dim]] = []
for rl in iinfo.result_levels:
if rl.is_positional():
result_dims.append(rl.position())
else:
result_dims.append(rl.dim())
raise DimensionBindError(
f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
f"({tuple(result_dims)!r})"
)
# Match RHS tensor to result levels
assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
matched_rhs = _match_levels(
rhs_info.tensor, rhs_info.levels, iinfo.result_levels
)
else:
matched_rhs = rhs
# For advanced indexing with dimensions, we need special handling
if iinfo.advanced_indexing:
# Use advanced indexing - the flat_inputs already contain matched tensors
tup = slice_to_tuple(iinfo.flat_inputs)
if iinfo.self_tensor is None:
raise RuntimeError("Cannot setitem on None tensor")
torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
else:
# Simple copy operation
if iinfo.self_tensor is None:
raise RuntimeError("Cannot copy to None tensor")
iinfo.self_tensor.copy_(matched_rhs)
def invoke_getitem(iinfo: IndexingInfo) -> Any:
if iinfo.advanced_indexing:
self_tensor = iinfo.self_tensor
tup = slice_to_tuple(iinfo.flat_inputs)
if self_tensor is None:
raise RuntimeError("Cannot getitem on None tensor")
rtensor = self_tensor[tup]
else:
rtensor = iinfo.self_tensor # type: ignore[assignment]
if rtensor is None:
raise RuntimeError("Cannot getitem on None tensor")
# rtensor is now guaranteed to be not None
# Create a Tensor with the proper dimensions using the class method
from . import Tensor
return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
from . import DimList # Import DimList for type checking
can_call_original_getitem = not tensors_have_dims
input_list = []
if has_dims(index):
input_list.append(index)
else:
is_sequence = extractIndices(index, input_list)
# nothing about first class dims here, fallback to getitem
if can_call_original_getitem and not is_sequence:
return IndexingInfo(can_call_original=True)
# Calculate how many dimensions have been indexed in order to compute the
# size of ... or expand a potentially unbound dimension list.
dims_indexed = 0
expanding_object = -1
unbound_dim_list = None
dimlists = [] # Track DimList positions for later processing
def check_expanding(i: int) -> None:
nonlocal expanding_object
if expanding_object != -1:
from . import DimensionBindError
raise DimensionBindError(
f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
f"{expanding_object} and {i}"
)
expanding_object = i
def is_dimpack(s: Any) -> bool:
from . import Dim
return (
isinstance(s, (tuple, list))
and len(s) > 0
and all(Dim.check_exact(item) for item in s)
)
has_dimpacks_or_none = False
for i, s in enumerate(input_list):
if has_dims(s):
can_call_original_getitem = False
dims_indexed += 1
elif s is ...:
check_expanding(i)
elif isinstance(s, DimList):
can_call_original_getitem = False
if not s.is_bound:
check_expanding(i)
unbound_dim_list = s
else:
dims_indexed += len(s._dims)
dimlists.append(i)
elif s is None:
has_dimpacks_or_none = True
elif is_dimpack(s):
can_call_original_getitem = False
has_dimpacks_or_none = True
dims_indexed += 1
else:
dims_indexed += 1
# Early return if we can use original getitem
if can_call_original_getitem:
return IndexingInfo(can_call_original=True)
self_info = TensorInfo.create(self, False, True)
total_dims = len(self_info.levels) # Total dimensions (positional + named)
if dims_indexed > total_dims:
raise ValueError(
f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
)
# Expand any unbound dimension list, or expand ... into individual : slices.
expanding_dims = total_dims - dims_indexed
if expanding_object != -1:
if unbound_dim_list is not None:
# Bind unbound dimension list to the expanding dimensions
unbound_dim_list.bind_len(expanding_dims)
else:
# Expand ... into slice(None) objects
no_slices = [slice(None)] * expanding_dims
input_list = (
input_list[:expanding_object]
+ no_slices
+ input_list[expanding_object + 1 :]
)
# Flatten out any dimensions stored in dimlist elements directly into the inputs
# Process in reverse order to maintain indices
for i in range(len(dimlists) - 1, -1, -1):
idx = dimlists[i]
# We added more elements to input because of ...
# so we need to also adjust the index to get back to where the
# dimlist existed
if (
unbound_dim_list is None
and expanding_object != -1
and idx > expanding_object
):
idx += expanding_dims
dl = input_list[idx]
# PRIVATE here naughty
input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
def getsetitem_flat(
self_info: TensorInfo,
input_list: list,
keys: list[DimEntry],
values: list,
has_dimpacks_or_none: bool,
) -> IndexingInfo:
from . import Dim
# Track dimension usage
seen_dims: list[Any] = []
seen_dims_nuses: list[int] = []
def add_dim(dim: Any) -> None:
# Use safe indexing to avoid triggering __torch_function__ on Dim objects
idx = _safe_index(seen_dims, dim)
if idx is not None:
seen_dims_nuses[idx] += 1
else:
seen_dims.append(dim)
seen_dims_nuses.append(1)
flat_inputs = []
tensor_inputs: list[Any] = []
device_holding_tensor = None
def append_flat_handle(handle: Any) -> None:
flat_inputs.append(handle)
tensor_inputs.append(None)
def append_tensor_input(ti: TensorInfo) -> None:
flat_inputs.append(None)
tensor_inputs.append(ti)
nonlocal device_holding_tensor
if ti.has_device and device_holding_tensor is None:
device_holding_tensor = ti.tensor
nsz = []
nsd = []
if self_info.tensor is None:
raise RuntimeError("Cannot get size/stride on None tensor")
sz = self_info.tensor.size()
sd = self_info.tensor.stride()
def append_size(i: int) -> None:
if has_dimpacks_or_none:
nsz.append(sz[i])
nsd.append(sd[i])
input_it = input_list[:]
def parse_nones() -> None:
nonlocal input_it
while input_it and input_it[0] is None:
append_flat_handle(slice(None))
nsz.append(1)
nsd.append(0)
input_it = input_it[1:]
def append_item(i: int, arg: Any) -> None:
if Dim.check_exact(arg):
d = arg
if d._size == -1:
d.size = sz[i]
add_dim(d)
append_size(i)
append_flat_handle(arg)
return
info = TensorInfo.create(arg, False, False)
if info:
append_size(i)
append_tensor_input(info)
for level in info.levels:
if not level.is_positional():
add_dim(level.dim())
return
if has_dimpacks_or_none:
if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
# dim pack
dim_pack = list(arg)
for d in dim_pack:
add_dim(d)
append_flat_handle(d)
_bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
return
append_size(i)
append_flat_handle(arg)
# Match indexing expressions with tensor dimensions
for i, level in enumerate(self_info.levels):
# Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
idx = _safe_index(keys, level)
if idx is not None:
append_item(i, values[idx])
else:
if level.is_positional():
parse_nones()
if not input_it:
append_flat_handle(slice(None))
append_size(i)
else:
arg = input_it[0]
input_it = input_it[1:]
append_item(i, arg)
else:
add_dim(level.dim())
append_flat_handle(level.dim())
append_size(i)
parse_nones()
# Restride tensor if needed
if has_dimpacks_or_none and nsz:
if self_info.tensor is None:
raise RuntimeError("Cannot restride None tensor")
self_tensor = self_info.tensor.as_strided(
nsz, nsd, self_info.tensor.storage_offset()
)
else:
self_tensor = self_info.tensor
# Determine result shape and indexing requirements
result_levels: list[Any] = []
index_levels = []
tensor_insert_point = -1
requires_getindex = False
def mark_tensor_index() -> None:
nonlocal tensor_insert_point
if tensor_insert_point == -1:
tensor_insert_point = len(result_levels)
elif tensor_insert_point != len(result_levels):
tensor_insert_point = 0
for i, inp in enumerate(flat_inputs):
if tensor_inputs[i] is not None:
requires_getindex = True
mark_tensor_index()
for level in tensor_inputs[i].levels:
if level not in index_levels:
index_levels.append(level)
elif Dim.check_exact(inp):
d = inp
# Use safe indexing to avoid triggering __torch_function__
dim_idx = _safe_index(seen_dims, d)
assert dim_idx is not None, f"Dim {d} not found in seen_dims"
if seen_dims_nuses[dim_idx] == 1:
flat_inputs[i] = slice(None)
result_levels.append(DimEntry(d))
else:
requires_getindex = True
flat_inputs[i] = None
tensor_inputs[i] = TensorInfo(
d._get_range(), [DimEntry(d)], False, None
)
if DimEntry(d) not in index_levels:
index_levels.append(DimEntry(d))
mark_tensor_index()
else:
if inp != slice(None):
requires_getindex = True
if not isinstance(inp, int):
result_levels.append(DimEntry(-1))
# Insert indexing dimensions at first tensor use point
if tensor_insert_point != -1:
for level in reversed(index_levels):
result_levels.insert(tensor_insert_point, level)
# Match tensors to indexing shape
if requires_getindex:
for i in range(len(flat_inputs)):
if tensor_inputs[i] is not None:
t = tensor_inputs[i].tensor
assert t is not None, "TensorInfo should have valid tensor data"
if (
not tensor_inputs[i].has_device
and device_holding_tensor is not None
):
t = t.to(device_holding_tensor.device)
flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
# Number positional dimensions correctly
seen_positionals = 0
for i in reversed(range(len(result_levels))):
if result_levels[i].is_positional():
seen_positionals += 1
result_levels[i] = DimEntry(-seen_positionals)
return IndexingInfo(
can_call_original=False,
advanced_indexing=requires_getindex,
self_tensor=self_tensor,
flat_inputs=flat_inputs,
result_levels=result_levels,
has_device=self_info.has_device,
)

functorch/dim/_order.py (new file, 214 lines)

@ -0,0 +1,214 @@
from __future__ import annotations
from typing import Any, TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Sequence
import torch # noqa: TC002
from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
"""
Convert various dimension representations to DimEntry.
Args:
arg: The argument to convert (Dim, int, or other)
orig_ndim: Original number of dimensions
allow_none: Whether to allow None values
Returns:
DimEntry representation of the dimension
"""
from . import Dim
if arg is None and allow_none:
return DimEntry() # None entry
elif isinstance(arg, Dim):
return DimEntry(arg)
elif isinstance(arg, int):
if arg < 0:
pos = arg
else:
pos = arg - orig_ndim
return DimEntry(pos)
else:
return DimEntry()
def order(
tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
) -> torch.Tensor:
"""
Reorder the dimensions of a tensor or create a tensor from a dimension.
It allows reordering tensor dimensions using first-class dimensions and
positional indices.
Args:
tensor_or_dim: Input tensor with first-class dimensions, or a Dim object
*dims: Dimensions or sequences of dimensions specifying the new order
Returns:
Tensor with reordered dimensions
Examples:
>>> import torch
>>> from functorch.dim import dims
>>> batch, channel, height, width = dims(4)
>>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
>>> # Reorder to [height, width, batch, channel]
>>> y = order(x, height, width, batch, channel)
"""
from . import Dim, DimList, Tensor
# Handle first argument - tensor or dimension
if isinstance(tensor_or_dim, Tensor):
# First-class tensor
orig_levels = tensor_or_dim._levels[:]
data = tensor_or_dim._tensor
has_device = tensor_or_dim._has_device
elif isinstance(tensor_or_dim, Dim):
# Single dimension - create range tensor
orig_levels = [DimEntry(tensor_or_dim)]
data = tensor_or_dim._get_range()
has_device = False
else:
raise ValueError("First argument must be a Tensor or Dim object")
flat_positional_dims = []
to_flatten = [] # List of (start_index, length) pairs for flattening
levels = orig_levels[:]
orig_ndim = ndim_of_levels(levels)
def append_dim(d: DimEntry) -> None:
"""Add a dimension to the reordering, removing it from available levels."""
try:
idx = levels.index(d)
except ValueError:
idx = None
if idx is None:
if d.is_positional():
raise ValueError(
f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
f"or it was specified twice"
)
else:
raise ValueError(
f"tensor does not contain dim {d.dim()} or it was specified twice"
)
levels[idx] = DimEntry()
flat_positional_dims.append(d)
n_new_positional = 0
# Process each dimension argument
for arg in dims:
entry = _wrap_dim(arg, orig_ndim, False)
if not entry.is_none():
append_dim(entry)
n_new_positional += 1
elif isinstance(arg, DimList):
# Handle DimList
for dim in arg._dims:
append_dim(DimEntry(dim))
n_new_positional += 1
else:
# Handle sequences of dimensions for flattening
n_new_positional += 1
if not hasattr(arg, "__iter__"):
raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
# Convert to list to get length
seq = list(arg)
to_flatten.append((len(flat_positional_dims), len(seq)))
for item in seq:
entry = _wrap_dim(item, orig_ndim, False)
if entry.is_none():
raise ValueError("expected a Dim or int")
append_dim(entry)
# Build new level ordering
insert_point = -1
new_levels: list[DimEntry] = []
# Add remaining (non-reordered) levels, finding insertion point for new dimensions
for level in levels:
if level.is_none():
continue
if level.is_positional():
if insert_point == -1:
insert_point = len(new_levels)
new_levels.extend(flat_positional_dims)
new_levels.append(level)
# If no positional dimensions found, append new dims at the end
if insert_point == -1:
insert_point = len(new_levels)
new_levels.extend(flat_positional_dims)
# Match tensor to new level structure
assert data is not None, "Cannot reorder None tensor"
ndata = _match_levels(data, orig_levels, new_levels)
# Handle dimension flattening if requested
if to_flatten:
# Now build the reshape target
view_shape = []
sizes = ndata.size()
# Add dimensions before the reordered ones
for i in range(insert_point):
view_shape.append(sizes[i])
# Process flattening groups
i = 0
for start_idx, length in to_flatten:
# Add individual dims before this flattening group
while i < start_idx:
view_shape.append(sizes[insert_point + i])
i += 1
# Flatten the group
new_size = 1
for j in range(length):
new_size *= sizes[insert_point + i + j]
view_shape.append(new_size)
i += length
# Add remaining individual dims
while i < len(flat_positional_dims):
view_shape.append(sizes[insert_point + i])
i += 1
# Add dimensions after the reordered ones
for i in range(insert_point + len(flat_positional_dims), len(levels)):
view_shape.append(sizes[i])
# Update levels by removing flattened dimensions
n_to_remove = len(flat_positional_dims) - n_new_positional
if n_to_remove > 0:
# Remove flattened levels
new_levels = (
new_levels[:insert_point] + new_levels[insert_point + n_to_remove :]
)
ndata = ndata.reshape(view_shape)
# Renumber positional dimensions (negative indexing from the right)
seen = 0
for i in range(len(new_levels) - 1, -1, -1):
if new_levels[i].is_positional() or (
i >= insert_point and i < insert_point + n_new_positional
):
seen -= 1
new_levels[i] = DimEntry(seen)
result = Tensor.from_positional(ndata, new_levels, has_device)
return result # type: ignore[return-value]


@ -0,0 +1,67 @@
import dis
from typing import Any, Optional
class _PyInstDecoder:
"""
Decodes Python bytecode instructions to extract variable names
"""
def __init__(self, code_object: Any, lasti: int) -> None:
self.code_object = code_object
self.instructions = list(dis.get_instructions(code_object))
self.offset = self._find_instruction_index(lasti)
def _find_instruction_index(self, lasti: int) -> int:
"""Find instruction index corresponding to lasti (byte offset)."""
# Find the instruction at or before lasti
# This should find the CALL instruction, not the next one
best_idx = 0
for i, instr in enumerate(self.instructions):
if instr.offset <= lasti:
best_idx = i
else:
break
return best_idx
def next(self) -> None:
"""Advance to the next instruction."""
self.offset += 1
def opcode(self) -> Optional[str]:
"""Get the opcode name of the current instruction."""
if self.offset < len(self.instructions):
return self.instructions[self.offset].opname
return None
def oparg(self) -> int:
"""Get the argument of the current instruction."""
if self.offset < len(self.instructions):
return self.instructions[self.offset].arg or 0
return 0
def name(self) -> Optional[str]:
"""
Extract variable name from current instruction.
"""
opname = self.opcode()
if not opname:
return None
names = None
if opname in ("STORE_NAME", "STORE_GLOBAL"):
names = self.code_object.co_names
elif opname == "STORE_FAST":
names = self.code_object.co_varnames
elif opname == "STORE_DEREF":
names = self.code_object.co_cellvars
if not names:
names = self.code_object.co_freevars
else:
return None
arg = self.oparg()
if names and 0 <= arg < len(names):
return names[arg]
return None


@ -0,0 +1,68 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from ._dim_entry import DimEntry
@dataclass
class TensorInfo:
tensor: Optional[torch.Tensor]
levels: list[DimEntry]
has_device: bool
batchedtensor: Optional[torch.Tensor]
def __post_init__(self) -> None:
from ._dim_entry import DimEntry
assert all(isinstance(l, DimEntry) for l in self.levels)
def ndim(self) -> int:
from ._dim_entry import ndim_of_levels
return ndim_of_levels(self.levels)
def __bool__(self) -> bool:
return self.tensor is not None
@staticmethod
def create(
h: Any, ensure_batched: bool = True, ensure_present: bool = True
) -> TensorInfo:
from . import Dim, DimEntry, Tensor
if Tensor.check_exact(h):
# functorch Tensor with first-class dimensions
return TensorInfo(
h._get_tensor(),
h._get_levels(),
h._get_has_device(),
h._get_batchtensor() if ensure_batched else None,
)
elif Dim.check_exact(h):
# For Dim objects, only get range/batchtensor if needed and dimension is bound
tensor = h._get_range() if h.is_bound else None
batchtensor = (
h._get_batchtensor() if ensure_batched and h.is_bound else None
)
return TensorInfo(
tensor,
[DimEntry(h)],
False,
batchtensor,
)
elif isinstance(h, torch.Tensor):
# Plain torch tensor - create positional levels
levels = []
for i in range(-h.dim(), 0):
levels.append(DimEntry(i))
return TensorInfo(h, levels, True, h)
else:
if ensure_present:
raise ValueError("expected a tensor object")
return TensorInfo(None, [], False, None)

functorch/dim/_wrap.py (new file, 263 lines)

@ -0,0 +1,263 @@
"""
Python implementation of function wrapping functionality for functorch.dim.
"""
from __future__ import annotations
import functools
from typing import Any, Callable, Optional
import torch
from torch.utils._pytree import tree_map
from ._dim_entry import DimEntry
from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Handle tensor conversion for torch function integration."""
return tensor
class WrappedOperator:
"""
This class wraps PyTorch operations to support first-class dimensions.
"""
def __init__(
self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
):
self.orig = orig
self.wrapper_implementation = wrapper_implementation
self.name = getattr(orig, "__name__", "")
self.doc = getattr(orig, "__doc__", None)
self.dim_name = dim_name
self.is_pointwise = False
self.dim_offset = 0
self.keepdim_offset = 1
self.single_dim = False
self.reduce = True
# Update docstring if we have a dim_name
if self.doc and self.dim_name:
self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"
def function(self) -> Callable:
"""Create a wrapped function that calls our wrapper implementation."""
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
return self.wrapper_implementation(self, *args, **kwargs)
# Copy metadata using functools.update_wrapper for just __name__ and __doc__
functools.update_wrapper(
wrapped_func, self.orig, assigned=("__name__",), updated=()
)
wrapped_func.__doc__ = self.doc
return wrapped_func
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
"""Convert single dimension specification to DimEntry object."""
from . import Dim
if isinstance(dim, Dim):
if keepdim:
raise ValueError("cannot preserve first-class dimensions with keepdim=True")
return DimEntry(dim)
elif isinstance(dim, int):
i = dim
while i >= 0:
i -= ndim
return DimEntry(i)
else:
return DimEntry()
def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
"""Convert dimension specification to list of DimEntry objects."""
de = _wrap_dim(dim, ndim, keepdim)
result = []
if not de.is_none():
result.append(de)
else:
for d in dim:
result.append(_wrap_dim(d, ndim, keepdim))
return result
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
"""
This is the core method that handles dimension-aware operations.
"""
if not args:
raise ValueError("Expected at least one argument (self)")
# Get dimension argument
dim_arg = kwargs.get(wrapper.dim_name)
if dim_arg is None and wrapper.dim_offset < len(args):
# Try to get dim from positional args (accounting for self at index 0)
dim_idx = wrapper.dim_offset + 1
if dim_idx < len(args):
dim_arg = args[dim_idx]
# If no dimension argument provided, fall back to standard functorch handling
if dim_arg is None:
info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
if not info:
return wrapper.orig(*args, **kwargs)
with EnableAllLayers(info.levels) as guard:
assert info.batchedtensor is not None
guard.inplace_update_layers(info.batchedtensor, info.levels)
new_args = list(args)
new_args[0] = handle_from_tensor(info.batchedtensor)
result = wrapper.orig(*new_args, **kwargs)
return guard.from_batched(result, info.has_device)
# Handle dimension-aware operation
info = TensorInfo.create(args[0])
if not info:
return wrapper.orig(*args, **kwargs)
# Check for keepdim parameter
keepdim = False
if wrapper.reduce:
keepdim_arg = kwargs.get("keepdim")
if keepdim_arg is None and wrapper.keepdim_offset < len(args):
keepdim_idx = wrapper.keepdim_offset + 1
if keepdim_idx < len(args):
keepdim_arg = args[keepdim_idx]
if keepdim_arg is not None:
keepdim = bool(keepdim_arg)
# Wrap dimensions
ndim = info.ndim()
dims = _wrap_dims(dim_arg, ndim, keepdim)
# Convert dimensions to indices and validate
dim_indices: list[int] = []
seen = [False] * len(info.levels)
for d in dims:
midx = None
for i, level in enumerate(info.levels):
if level == d:
midx = i
break
if midx is None:
# Try to match by position/name more flexibly
for i, level in enumerate(info.levels):
if hasattr(level, "matches") and level.matches(d):
midx = i
break
if midx is None:
level_strs = [str(level) for level in info.levels]
raise ValueError(
f"Tensor with dimensions {level_strs} does not contain {d}"
)
seen[midx] = True
dim_indices.append(midx)
# Determine new levels after reduction
new_levels = []
if wrapper.reduce and not keepdim:
for i, level in enumerate(info.levels):
if not seen[i]:
new_levels.append(level)
else:
new_levels = info.levels[:]
# Create dimension indices for the original function
if len(dim_indices) == 1:
py_indices: Any = dim_indices[0]
else:
py_indices = tuple(dim_indices)
# Update arguments
new_args = list(args)
new_kwargs = kwargs.copy()
assert info.tensor is not None
new_args[0] = handle_from_tensor(info.tensor)
# Update dimension argument
if wrapper.dim_name in new_kwargs:
new_kwargs[wrapper.dim_name] = py_indices
else:
dim_idx = wrapper.dim_offset + 1
if dim_idx < len(new_args):
new_args = list(new_args)
new_args[dim_idx] = py_indices
# Call original function
result = wrapper.orig(*new_args, **new_kwargs)
# Wrap results
def wrap_result(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
from . import Tensor
return Tensor.from_positional(obj, new_levels, info.has_device)
return obj
return tree_map(wrap_result, result)
def _wrap(
orig: Callable,
dim_offset: Optional[int] = None,
keepdim_offset: Optional[int] = None,
dim_name: Optional[str] = None,
single_dim: Optional[bool] = None,
reduce: Optional[bool] = None,
) -> Callable:
"""
Wrap a PyTorch function to support first-class dimensions.
Args:
orig: Original function to wrap
dim_offset: Offset for dimension argument (default: 0)
keepdim_offset: Offset for keepdim argument (default: 1)
dim_name: Name of dimension parameter (default: "dim")
single_dim: Whether function takes single dimension (default: False)
reduce: Whether function reduces dimensions (default: True)
"""
dim_name = dim_name or "dim"
wrapper = WrappedOperator(orig, patched_dim_method, dim_name)
if dim_offset is not None:
wrapper.dim_offset = dim_offset
if keepdim_offset is not None:
wrapper.keepdim_offset = keepdim_offset
if single_dim is not None:
wrapper.single_dim = single_dim
if reduce is not None:
wrapper.reduce = reduce
return wrapper.function()
def call_torch_function(
wrapper: WrappedOperator,
func: Callable,
types: tuple,
args: tuple = (),
kwargs: Optional[dict] = None,
) -> Any:
"""
Handle __torch_function__ calls for wrapped operators.
"""
if kwargs is None:
kwargs = {}
# Import here to avoid circular imports
from . import _Tensor
# Use the torch function mechanism from _Tensor
return _Tensor.__torch_function__(func, types, args, kwargs)


@ -6,11 +6,14 @@
import os
import signal
import subprocess
from collections.abc import Generator
from contextlib import contextmanager
@contextmanager
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
def magic_trace(
output: str = "trace.fxt", magic_trace_cache: str = "/tmp/magic-trace"
) -> Generator[None, None, None]:
pid = os.getpid()
if not os.path.exists(magic_trace_cache):
print(f"Downloading magic_trace to: {magic_trace_cache}")
@ -26,6 +29,7 @@ def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
subprocess.run(["chmod", "+x", magic_trace_cache])
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
assert p.stderr is not None
while True:
x = p.stderr.readline()
print(x)
@ -36,7 +40,8 @@ def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
finally:
p.send_signal(signal.SIGINT)
r = p.wait()
print(p.stderr.read())
p.stderr.close()
if p.stderr is not None:
print(p.stderr.read())
p.stderr.close()
if r != 0:
raise ValueError(f"magic_trace exited abnormally: {r}")


@ -1,15 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functorch._C import dim
tree_flatten = dim.tree_flatten
def tree_map(fn, tree):
vs, unflatten = tree_flatten(tree)
return unflatten(fn(v) for v in vs)


@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from types import (
BuiltinMethodType,
FunctionType,
@ -11,11 +12,8 @@ from types import (
MethodDescriptorType,
WrapperDescriptorType,
)
from typing import Any, Callable
from functorch._C import dim as _C
_wrap_method = _C._wrap_method
FUNC_TYPES = (
FunctionType,
@ -26,14 +24,24 @@ FUNC_TYPES = (
PROPERTY_TYPES = (GetSetDescriptorType, property)
def wrap_type(to_patch, pattern, __torch_function__):
wrap_method = _wrap_method
def _py_wrap_method(orig: Callable, __torch_function__: Callable) -> Callable:
def impl(*args: Any, **kwargs: Any) -> Any:
return __torch_function__(orig, None, args, kwargs)
all = {}
# Copy metadata using functools.update_wrapper for just __name__ and __doc__
functools.update_wrapper(impl, orig, assigned=("__name__", "__doc__"), updated=())
return impl
def wrap_type(to_patch: Any, pattern: type, __torch_function__: Callable) -> None:
wrap_method = _py_wrap_method
all: dict[str, Any] = {}
for t in reversed(pattern.mro()[:-1]): # skip object
all.update(t.__dict__)
def wrap_attr(orig):
def wrap_attr(orig: Any) -> property:
return property(wrap_method(orig.__get__, __torch_function__))
for name, obj in all.items():


@ -4,7 +4,7 @@ import functools
from typing import Callable, TYPE_CHECKING, Union
import torch
from functorch._C import dim as _C
from functorch.dim import dims # noqa: F401
from ._parsing import (
_ellipsis,
@ -20,8 +20,6 @@ if TYPE_CHECKING:
__all__ = ["rearrange"]
dims = _C.dims
@functools.lru_cache(256)
def _create_rearrange_callable(