from __future__ import annotations
import dis
import inspect
import sys
from typing import Any, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
from ._enable_all_layers import EnableAllLayers
from ._py_inst_decoder import _PyInstDecoder
from ._tensor_info import TensorInfo
POINTWISE_OPTIMIZE = True
DOT_OPTIMIZED = True
# Global dimension level counter
_n_dims_created = 0
def _relevant_op(opcode: Optional[str]) -> bool:
"""Check if opcode is relevant for variable assignment."""
return bool(opcode and opcode.startswith("STORE_"))
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Handle tensor conversion for torch function integration."""
return tensor
def _create_dim(name: str, size: Optional[int] = None) -> Dim:
"""Create a new Dim object."""
return Dim(name, size if size is not None else -1)
def dims(
n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
) -> Union[Dim, tuple[Dim, ...]]:
"""
Create and return one or more Dim objects.
Uses bytecode inspection to determine variable names when possible.
Args:
n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
sizes (List[Optional[int]], optional): A list the same size as the number of dimensions to be
created, specifying each dimension's size, or None to leave that size unset.
Returns:
Union[Dim, Tuple[Dim, ...]]: Single Dim if n=1, tuple of Dims otherwise.
Examples:
>>> batch, channel, width, height = dims(4)
>>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
>>> single_dim = dims(1)
"""
specified_ndims = -1
found_ndims = 0
# Parse arguments
if sizes is not None:
specified_ndims = len(sizes)
if n is not None:
specified_ndims = n
# Use bytecode inspection
frame = inspect.currentframe()
if frame is None:
raise RuntimeError("Unable to get current frame")
frame = frame.f_back
try:
if frame is None:
raise RuntimeError("Unable to get caller frame")
code = frame.f_code
lasti = frame.f_lasti
decoder = _PyInstDecoder(code, lasti)
if sys.version_info >= (3, 11):
if decoder.opcode() == "PRECALL":
decoder.next()
# Move to next instruction after the call
decoder.next()
# Determine number of dimensions from bytecode
if _relevant_op(decoder.opcode()):
found_ndims = 1
elif decoder.opcode() == "UNPACK_SEQUENCE":
found_ndims = decoder.oparg()
decoder.next() # Move past UNPACK_SEQUENCE
if specified_ndims == -1:
if found_ndims == 0:
raise SyntaxError(
"dims() must be assigned to a sequence of variable names or have argument n specified"
)
specified_ndims = found_ndims
if found_ndims != specified_ndims:
found_ndims = 0
def genobject(i: int) -> Dim:
nonlocal found_ndims
name = None
if i < found_ndims:
name = decoder.name()
if not name:
name = f"d{i}"
found_ndims = 0
else:
decoder.next() # Move to next STORE instruction
size = sizes[i] if sizes is not None else None
return _create_dim(name, size)
# Validate sizes parameter
if sizes is not None and len(sizes) != specified_ndims:
raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
if specified_ndims == 1:
return genobject(0)
result = []
for i in range(specified_ndims):
result.append(genobject(i))
return tuple(result)
finally:
del frame
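# Illustrative note (not part of the public API surface): because dims() inspects the
# caller's bytecode, the left-hand-side variable names become the Dim names; when no
# STORE instruction can be found, names fall back to "d0", "d1", ...
#     >>> batch, channel = dims(2)
#     >>> repr(batch), repr(channel)
#     ('batch', 'channel')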
class DimList:
"""
A list of first-class dimensions that can be bound to tensor dimensions.
A DimList can be in one of two states:
1. Unbound: Created with just a name, no specific dimensions yet
2. Bound: Either created with specific dimensions/sizes, or bound later via bind() or bind_len()
"""
_name: Optional[str]
_dims: list[Dim]
_bound: bool
def __init__(
self,
len_or_dims: Optional[Union[int, Sequence]] = None,
name: Optional[str] = None,
):
"""
Initialize a new DimList object.
Args:
len_or_dims: Optional length (int) or sequence of dimensions/sizes
name: Optional name for the dimension list
"""
# Initialize attributes
self._name = name
self._dims: list = []
self._bound = False
if isinstance(len_or_dims, int):
self.bind_len(len_or_dims)
elif len_or_dims is not None:
dims = []
for i, item in enumerate(len_or_dims):
if isinstance(item, int):
dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
dims.append(Dim(dim_name, item))
else:
dims.append(Dim(item))
self._set_dims(dims)
def _set_dims(self, dims: list) -> None:
"""Set the dimensions and mark as bound."""
self._bound = True
self._dims = dims
def bind_len(self, size: int) -> None:
"""
Bind this DimList to a specific length.
Args:
size: Number of dimensions to bind to
Raises:
DimensionBindError: If already bound to a different size
"""
if self._bound:
if len(self._dims) != size:
raise DimensionBindError(
f"Dimlist has size {len(self._dims)} but it is being bound to size {size}"
)
else:
self._bound = True
self._dims = []
for i in range(size):
dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
self._dims.append(Dim(dim_name))
def bind(self, sizes: Sequence[int]) -> None:
"""
Bind this DimList to specific sizes.
Args:
sizes: Sequence of sizes for each dimension
Raises:
ValueError: If sizes is not a sequence
"""
if not hasattr(sizes, "__len__") or not hasattr(sizes, "__getitem__"):
raise ValueError("expected a sequence")
size = len(sizes)
self.bind_len(size)
for i, dim_size in enumerate(sizes):
self._dims[i].size = int(dim_size)
def _size(self) -> int:
if not self._bound:
raise DimensionBindError("DimList not bound")
return len(self._dims)
def size(self) -> int:
"""Return the size (number of dimensions) of this DimList."""
return self._size()
def _set_bound(self, b: bool) -> None:
"""Set the bound status (for internal use)."""
self._bound = b
@property
def is_bound(self) -> bool:
"""Property to check if DimList is bound."""
return self._bound
def __len__(self) -> int:
"""Return the length of the DimList."""
return self.size()
def __getitem__(self, key: Union[int, slice]) -> Union[Dim, tuple[Dim, ...]]:
if not self._bound:
raise DimensionBindError("DimList not bound")
if isinstance(key, int):
if key < 0 or key >= len(self._dims):
raise IndexError("index out of bounds")
return self._dims[key]
elif isinstance(key, slice):
start, stop, step = key.indices(len(self._dims))
result = []
for i in range(start, stop, step):
result.append(self._dims[i])
return tuple(result)
else:
raise ValueError("expected an int or a slice")
def __repr__(self) -> str:
"""Return string representation of the DimList."""
if self._bound:
# Show as tuple representation
return f"({', '.join(repr(dim) for dim in self._dims)})"
elif self._name is not None:
# Show as *name for unbound with name
return f"*{self._name}"
else:
# Show as <unbound_dimlist> for unbound without name
return "<unbound_dimlist>"
def __str__(self) -> str:
"""Return string representation of the DimList."""
return self.__repr__()
@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple,
args: tuple = (),
kwargs: Optional[dict] = None,
) -> Any:
return _Tensor.__torch_function__(func, types, args, kwargs)
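# Usage sketch for DimList (illustrative, following the two states described in the
# class docstring):
#     >>> dl = DimList(name="batch")   # unbound: repr() is "*batch"
#     >>> dl.is_bound
#     False
#     >>> dl.bind([2, 3])              # binds to dims "batch0" and "batch1"
#     >>> len(dl), dl[0].size, dl[1].size
#     (2, 2, 3)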
def _create_dimlist(
name: str, size: Optional[Union[int, list[Optional[int]]]] = None
) -> DimList:
"""Create a DimList object with the given name and optional size."""
dimlist = DimList(name=name)
if size is not None:
if isinstance(size, int):
dimlist.bind_len(size)
else:
# size is a list of optional ints
dimlist.bind_len(len(size))
for i, s in enumerate(size):
if s is not None:
dimlist._dims[i].size = s
return dimlist
def dimlists(
n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
) -> Union[DimList, tuple[DimList, ...]]:
"""
Create and return one or more DimList objects.
Similar to dims() but creates DimList objects instead.
"""
specified_ndims = -1
found_ndims = 0
# Parse arguments
if sizes is not None:
specified_ndims = len(sizes)
if n is not None:
specified_ndims = n
frame = inspect.currentframe()
if frame is None:
raise RuntimeError("Unable to get current frame")
frame = frame.f_back
try:
if frame is None:
raise RuntimeError("Unable to get caller frame")
code = frame.f_code
lasti = frame.f_lasti
decoder = _PyInstDecoder(code, lasti)
if sys.version_info >= (3, 11):
if decoder.opcode() == "PRECALL":
decoder.next()
# Move to next instruction after the call
decoder.next()
# Determine number of dimensions from bytecode
if _relevant_op(decoder.opcode()):
found_ndims = 1
elif decoder.opcode() == "UNPACK_SEQUENCE":
found_ndims = decoder.oparg()
decoder.next() # Move past UNPACK_SEQUENCE
if specified_ndims == -1:
if found_ndims == 0:
raise SyntaxError(
"dimlists() must be assigned to a sequence of variable names or have argument n specified"
)
specified_ndims = found_ndims
if found_ndims != specified_ndims:
found_ndims = 0
# Generator function for dimlist names
def genobject(i: int) -> str:
nonlocal found_ndims
name = None
if i < found_ndims:
name = decoder.name()
if not name:
name = f"d{i}"
found_ndims = 0
else:
decoder.next() # Move to next STORE instruction
return name
# Validate sizes
if sizes is not None and len(sizes) != specified_ndims:
raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
# Create dimlists
if specified_ndims == 1:
name = genobject(0)
return _create_dimlist(name, sizes[0] if sizes is not None else None)
result = []
for i in range(specified_ndims):
name = genobject(i)
size = sizes[i] if sizes is not None else None
result.append(_create_dimlist(name, size))
return tuple(result)
finally:
del frame
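# Illustrative example: like dims(), dimlists() recovers names from the assignment
# bytecode; the resulting DimLists start unbound unless sizes are given.
#     >>> xs, ys = dimlists(2)
#     >>> xs.is_bound
#     False
#     >>> xs.bind_len(3)    # later bound to three fresh dims xs0, xs1, xs2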
class DimensionMismatchError(Exception):
pass
class DimensionBindError(Exception):
pass
from . import op_properties
def _safe_print(*args: Any, **kwargs: Any) -> None:
"""Safe print that avoids recursive torch function dispatches."""
import sys
# Convert any torch objects to basic representations
safe_args = []
for arg in args:
if hasattr(arg, "__class__") and "torch" in str(type(arg)):
safe_args.append(f"<{type(arg).__name__}>")
else:
safe_args.append(str(arg))
print(*safe_args, **kwargs, file=sys.stderr)
class _Tensor:
def _get_levels(self) -> list[Any]:
raise NotImplementedError("_get_levels must be implemented by subclass")
def _get_tensor(self) -> Optional[torch.Tensor]:
raise NotImplementedError("_get_tensor must be implemented by subclass")
@property
def ndim(self) -> int:
raise NotImplementedError("ndim must be implemented by subclass")
@property
def dims(self) -> tuple[Any, ...]:
return tuple(l.dim() for l in self._get_levels() if not l.is_positional())
def dim(self) -> int:
return self.ndim
@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple,
args: tuple = (),
kwargs: Optional[dict] = None,
) -> Any:
if kwargs is None:
kwargs = {}
if DOT_OPTIMIZED and func is torch.Tensor.__mul__:
# Check conditions: 2 args, both are tensor-like, both 0-dimensional
if (
len(args) == 2
and not kwargs
and isinstance(args[0], (_Tensor, torch.Tensor))
and isinstance(args[1], (_Tensor, torch.Tensor))
):
# Get tensor info for both operands
lhs_info = TensorInfo.create(
args[0], ensure_batched=False, ensure_present=False
)
rhs_info = TensorInfo.create(
args[1], ensure_batched=False, ensure_present=False
)
if (
lhs_info
and rhs_info
and lhs_info.tensor is not None
and rhs_info.tensor is not None
and lhs_info.tensor.dim() == 0
and rhs_info.tensor.dim() == 0
):
if (
lhs_info.tensor.is_floating_point()
and rhs_info.tensor.is_floating_point()
):
# Collect all unique levels and has_device
has_device = lhs_info.has_device or rhs_info.has_device
levels = []
for level in lhs_info.levels:
if level not in levels:
levels.append(level)
for level in rhs_info.levels:
if level not in levels:
levels.append(level)
# Debug print
# print(f"DEBUG: Creating delayed mul, levels: {levels}, has_device: {has_device}")
# Create delayed tensor
return Tensor.create_delayed(func, args, levels, has_device)
if func is torch.Tensor.__getitem__:
from functorch.dim._getsetitem import getitem
return getitem(cls, func, types, args, kwargs)
if func is torch.Tensor.__setitem__:
from functorch.dim._getsetitem import setitem
# args should be (tensor, index, value)
if len(args) == 3:
setitem(args[0], args[1], args[2])
return None
else:
raise ValueError(f"Expected 3 args for __setitem__, got {len(args)}")
# Fast-path for len; mostly to avoid infinite loop in TestMinFunctorchOnly.test_softmax_split
if func is torch.Tensor.__len__:
return args[0].size(0)
# Special handling for torch.softmax - use the pre-wrapped version
if func is torch.softmax:
return softmax(*args, **kwargs)
# Special handling for torch.stack - use the custom stack function
if func is torch.stack:
return stack(*args, **kwargs)
if (
func is torch.Tensor.split
or func is torch._VF.split # type: ignore[attr-defined]
or func is torch._VF.split_with_sizes # type: ignore[attr-defined]
or func is torch.split
):
return split(*args, **kwargs)
return _Tensor._torch_function_fallback(func, types, args, kwargs)
@staticmethod
def _torch_function_fallback(
func: Callable, types: tuple, args: tuple, kwargs: dict
) -> Any:
"""Fallback torch function implementation for non-special-cased functions."""
is_pointwise = POINTWISE_OPTIMIZE and func in op_properties.pointwise
# TODO: optimize pytree here
flat_args, spec = tree_flatten((args, kwargs))
device_holding_tensor = None
infos: list[TensorInfo] = []
result_levels: list[DimEntry] = []
for f in flat_args:
info = TensorInfo.create(f, not is_pointwise, False)
infos.append(info)
if info:
assert is_pointwise or info.batchedtensor is not None
if device_holding_tensor is None and info.has_device:
device_holding_tensor = info.tensor
# Collect all unique levels
for level in info.levels:
assert isinstance(level, DimEntry)
if level not in result_levels:
result_levels.append(level)
if is_pointwise:
# Pointwise operation: match all tensors to common levels
for i, info in enumerate(infos):
if info and info.tensor is not None:
tensor = info.tensor
if device_holding_tensor is not None and not info.has_device:
tensor = tensor.to(device_holding_tensor.device)
ml = _match_levels(tensor, info.levels, result_levels)
flat_args[i] = handle_from_tensor(ml)
unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
result = func(*unflat_args, **unflat_kwargs)
# Wrap tensor results
def wrap_tensor(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return Tensor.from_positional(
obj, result_levels, device_holding_tensor is not None
)
return obj
# Small fastpath
if isinstance(result, torch.Tensor):
return wrap_tensor(result)
else:
return tree_map(wrap_tensor, result)
# Non-pointwise operation: use functorch vmap layers
with EnableAllLayers(result_levels) as guard:
# Update arguments with batched tensors
for i, info in enumerate(infos):
if info and info.batchedtensor is not None:
batched = info.batchedtensor
if device_holding_tensor is not None and not info.has_device:
batched = batched.to(device_holding_tensor.device)
guard.inplace_update_layers(batched, info.levels)
flat_args[i] = handle_from_tensor(batched)
unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
result = func(*unflat_args, **unflat_kwargs)
# Unwrap results from functorch layers
def unwrap_tensor(obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return guard.from_batched(obj, device_holding_tensor is not None)
return obj
if isinstance(result, torch.Tensor):
return unwrap_tensor(result)
else:
return tree_map(unwrap_tensor, result)
def __setitem__(self, index: Any, value: Any) -> None:
"""Set values in tensor using first-class dimensions."""
from functorch.dim._getsetitem import setitem
return setitem(self, index, value)
# expand and index are OK to be methods because they don't have torch.*
# versions, but if they did they would need the stack/cat treatment
def expand(self, *args: Dim) -> _Tensor:
"""
Expand tensor by adding new dimensions or expanding existing dimensions.
If all arguments are Dim objects, adds new named dimensions.
Otherwise, falls back to regular tensor expansion behavior.
Args:
args: Either Dim objects for new dimensions or sizes for regular expansion
Returns:
New tensor with expanded dimensions
Example:
>>> i, j, k = dims()
>>> t = torch.randn(3, 4)
>>> expanded = t[i].expand(j, k) # Add j, k dimensions
>>> expanded2 = t[i].expand(2, 4) # Regular expand with sizes
"""
info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
for arg in args:
if not isinstance(arg, Dim):
# Not all args are Dims, fallback to regular expand
if isinstance(self, torch.Tensor) and not isinstance(self, _Tensor):
return torch.Tensor.expand(self, *args)
else:
return self.__torch_function__(
torch.Tensor.expand, (type(self),), (self,) + args
)
# All args are Dim objects - proceed with first-class dimension expansion
if not info:
# No tensor info available, fallback
return self.__torch_function__(
torch.Tensor.expand, (type(self),), (self,) + args
)
# First-class dimension expansion - all args are Dim objects
data = info.tensor
if data is None:
# No tensor data available, fallback
return self.__torch_function__(
torch.Tensor.expand, (type(self),), (self,) + args
)
levels = info.levels
new_levels: list[DimEntry] = []
new_sizes = []
new_strides = []
for d in args:
# Check if dimension already exists in current levels or new_levels
for level in levels:
if not level.is_positional() and level.dim() is d:
raise DimensionBindError(
f"expanding dimension {d} already exists in tensor with dims"
)
for new_level in new_levels:
if not new_level.is_positional() and new_level.dim() is d:
raise DimensionBindError(
f"expanding dimension {d} already exists in tensor with dims"
)
new_levels.append(DimEntry(d))
new_sizes.append(d.size)
new_strides.append(0)
# Add existing levels
new_levels.extend(levels)
# Add existing sizes and strides
orig_sizes = list(data.size())
orig_strides = list(data.stride())
new_sizes.extend(orig_sizes)
new_strides.extend(orig_strides)
# Create expanded tensor using as_strided
expanded_data = data.as_strided(new_sizes, new_strides, data.storage_offset())
# Return new tensor with expanded dimensions
result = Tensor.from_positional(expanded_data, new_levels, info.has_device)
return result # type: ignore[return-value] # Tensor and torch.Tensor are interchangeable
def index(
self,
dims: Union[int, Dim, tuple[Union[int, Dim], ...], list[Union[int, Dim]]],
indices: Union[
int,
slice,
torch.Tensor,
tuple[Union[int, slice, torch.Tensor], ...],
list[Union[int, slice, torch.Tensor]],
],
) -> _Tensor:
"""
Index tensor using first-class dimensions.
"""
from ._dim_entry import _match_levels
from ._getsetitem import getsetitem_flat, invoke_getitem
from ._wrap import _wrap_dim
# Helper to check if obj is a dimpack (tuple/list) and extract items
def maybe_dimpack(obj: Any, check_first: bool = False) -> tuple[Any, bool]:
if isinstance(obj, (tuple, list)):
return list(obj), True
return None, False
def parse_dim_entry(s: Any) -> Any:
d = _wrap_dim(s, self.ndim, False)
if d.is_none():
raise TypeError(f"expected a dimension specifyer but found {repr(s)}")
return d
# Helper for dimension not present errors
def dim_not_present(d: Any) -> None:
if d.is_positional():
raise TypeError(
f"dimension {d.position() + self.ndim} not in tensor of {self.ndim} dimensions"
)
else:
raise TypeError(f"dimension {repr(d.dim())} not in tensor")
dims_list: list[Union[int, Dim]] = []
indices_list: list[Union[int, slice, torch.Tensor]] = []
lhs_list = isinstance(dims, (tuple, list))
rhs_list = isinstance(indices, (tuple, list))
if lhs_list and rhs_list:
# Type narrowing: we know dims and indices are sequences here
dims_seq = dims # type: ignore[assignment]
indices_seq = indices # type: ignore[assignment]
if len(dims_seq) != len(indices_seq): # type: ignore[arg-type]
raise TypeError(
f"dims ({len(dims_seq)}) and indices ({len(indices_seq)}) must have the same length" # type: ignore[arg-type]
)
dims_list.extend(dims_seq) # type: ignore[arg-type]
indices_list.extend(indices_seq) # type: ignore[arg-type]
else:
dims_list.append(dims) # type: ignore[arg-type]
indices_list.append(indices) # type: ignore[arg-type]
# Create tensor info
self_info = TensorInfo.create(self, False, False)
new_levels: list[Any] = []
to_flatten: list[Any] = []
dims_list_flat = []
# Process each dim specification
for i in range(len(dims_list)):
m, is_dimpack = maybe_dimpack(dims_list[i], check_first=False)
if is_dimpack:
if len(m) == 0:
dims_list_flat.append(DimEntry()) # Empty dimpack
continue
first = parse_dim_entry(m[0])
dims_list_flat.append(first)
if len(m) == 1:
continue
# Multi-element dimpack requires flattening
if len(to_flatten) == 0:
new_levels.extend(self_info.levels)
rest = []
for j in range(1, len(m)):
d = parse_dim_entry(m[j])
removed = False
for k in range(len(new_levels)):
if new_levels[k] == d:
new_levels.pop(k)
removed = True
break
if not removed:
dim_not_present(d)
rest.append(d)
# Find first in new_levels
first_idx = None
for k in range(len(new_levels)):
if new_levels[k] == first:
first_idx = k
break
if first_idx is None:
dim_not_present(first)
continue  # unreachable: dim_not_present() always raises
for j, r in enumerate(rest):
new_levels.insert(first_idx + 1 + j, r)
to_flatten.extend(rest)
else:
dims_list_flat.append(parse_dim_entry(dims_list[i]))
# Handle dimension flattening if needed
if len(to_flatten) > 0:
assert self_info.tensor is not None, (
"Cannot perform dimension flattening on None tensor"
)
rearranged = _match_levels(self_info.tensor, self_info.levels, new_levels)
sizes = rearranged.size()
new_sizes: list[Any] = []
reshape_levels = []
for i in range(len(new_levels)):
if new_levels[i] in to_flatten:
if len(new_sizes) == 0:
new_sizes.append(sizes[i])
else:
new_sizes[-1] *= sizes[i]
else:
new_sizes.append(sizes[i])
reshape_levels.append(new_levels[i])
self_info.tensor = rearranged.reshape(new_sizes)
self_info.levels = reshape_levels
# Check for dimpacks in indices
has_dimpacks = False
for idx in indices_list:
if isinstance(idx, (tuple, list)):
has_dimpacks = True
break
# Call getsetitem_flat with correct parameters
info = getsetitem_flat(
self_info,
[], # empty input_list
dims_list_flat, # keys
indices_list, # values
has_dimpacks,
)
return invoke_getitem(info)
def __repr__(self) -> str:
tensor, levels, ndim = self._get_tensor(), self._get_levels(), self.ndim
dims_repr = []
for l in levels:
if hasattr(l, "is_positional") and l.is_positional():
# Convert negative positional to positive: -1 -> ndim-1, -2 -> ndim-2, etc.
dims_repr.append(l.position() + ndim)
elif hasattr(l, "dim"):
dims_repr.append(l.dim())
elif hasattr(l, "data"):
dims_repr.append(l.data)
else:
dims_repr.append(l)
return f"{tensor}\nwith dims={tuple(dims_repr)} sizes={tuple(tensor.size())}" # type: ignore[union-attr]
TensorLike = (_Tensor, torch.Tensor)
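# Illustrative sketch of the dispatch above: indexing a plain tensor with Dims yields
# first-class tensors, and pointwise ops align their named dims via level matching
# (non-pointwise ops instead go through the EnableAllLayers vmap path).
#     >>> i, j = dims(2)
#     >>> x = torch.randn(3)[i]    # no positional dims left; named dim i (i.size == 3)
#     >>> y = torch.randn(4)[j]
#     >>> z = x + y                # result carries dims (i, j)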
class Dim(_Tensor):
_level: int
_name: str
_size: int
_range: Optional[torch.Tensor]
_batchtensor: Optional[torch.Tensor]
def __init__(self, name: str, s: int = -1) -> None:
global _n_dims_created
self._name = name
self._size = s
self._level = _n_dims_created
_n_dims_created += 1
self._range = None
self._batchtensor = None
@property
def ndim(self) -> int:
return 1
@classmethod
def check_exact(cls, obj: Any) -> bool:
return type(obj) is cls
@property
def size(self) -> int:
if self._size == -1:
raise ValueError(f"dimension {self._name} is unbound")
return self._size
@size.setter
def size(self, v: int) -> None:
if self._size == -1:
self._size = v
elif self._size != v:
raise DimensionBindError(
f"Dim '{repr(self)}' previously bound to a dimension of size {self._size} "
f"cannot bind to a dimension of size {v}"
)
@property
def is_bound(self) -> bool:
"""Return True if this dimension is bound to a size."""
return self._size != -1
def _get_range(self) -> torch.Tensor:
"""
Get a tensor representing the range [0, size) for this dimension.
Returns:
A 1D tensor with values [0, 1, 2, ..., size-1]
"""
if self._range is None:
self._range = torch.arange(self.size)
return self._range
def _get_batchtensor(self) -> torch.Tensor:
"""
Get a batched tensor representation of this dimension.
Returns:
A batched tensor created from the range tensor
"""
if self._batchtensor is None:
self._batchtensor = torch._C._functorch._add_batch_dim(
self._get_range(), 0, self._level
)
return self._batchtensor
def __repr__(self) -> str:
"""String representation of a Dim object."""
return self._name
# note that Dim comes before Tensor because we want the Dim API for things like size to take precedence.
# Tensor defines __format__, but we want to print Dims with special formatting
__format__ = object.__format__
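# Size-binding sketch (illustrative): a Dim created without a size is unbound and
# binds to the first size it is assigned; rebinding to a different size raises
# DimensionBindError via the size setter above.
#     >>> d = Dim("d")
#     >>> d.is_bound
#     False
#     >>> d.size = 4    # first assignment binds the dim
#     >>> d.size = 4    # same size is fine; a different size would raise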
# Somewhat confusingly, an FCD tensor is also called Tensor. This confusion
# is somewhat intentional, as FCD tensors are intended to be substitutable
# with regular Tensor (just with some positional dims hidden).
class Tensor(_Tensor):
_tensor: Optional[torch.Tensor]
_batchtensor: Optional[torch.Tensor]
_levels: list[DimEntry]
_has_device: bool
_delayed: Optional[Callable[[], torch.Tensor]]
_delayed_orig: Optional[Callable]
_delayed_args: Optional[tuple]
@property
def ndim(self) -> int:
return sum(1 if l.is_positional() else 0 for l in self._levels)
@classmethod
def check_exact(cls, other: Any) -> bool:
return type(other) is cls
@classmethod
def from_positional(
cls, tensor: torch.Tensor, levels: list[DimEntry], has_device: bool
) -> Union[_Tensor, torch.Tensor]:
"""
Create a functorch Tensor from a regular PyTorch tensor with specified dimension levels.
This is the primary way to create Tensor objects with first-class dimensions.
Args:
tensor: The underlying PyTorch tensor
levels: List of DimEntry objects specifying the dimension structure
has_device: Whether the tensor is on a device (not CPU)
Returns:
A new Tensor instance with the specified dimensions, or a regular torch.Tensor
if there are no named dimensions
"""
seen_dims = 0
last = 0
for i, l in enumerate(levels):
if l.is_positional():
# Validate consecutive positional dimensions
assert last == 0 or last + 1 == l.position(), (
f"Positional dimensions must be consecutive, got {last} then {l.position()}"
)
last = l.position()
else:
# This is a named dimension
seen_dims += 1
# Validate final positional dimension
assert last == 0 or last == -1, (
f"Final positional dimension must be 0 or -1, got {last}"
)
if not seen_dims:
return tensor
# Create Tensor object with proper level management
result = cls()
result._tensor = tensor
result._levels = levels
result._has_device = has_device
result._batchtensor = None # Will be created lazily if needed
result._delayed = None
result._delayed_orig = None
result._delayed_args = None
# Validate tensor dimensionality matches levels
assert tensor.dim() == len(levels), (
f"Tensor has {tensor.dim()} dimensions but {len(levels)} levels provided"
)
return result
@classmethod
def create_delayed(
cls, orig: Callable, args: tuple, levels: list[DimEntry], has_device: bool
) -> _Tensor:
"""
Create a delayed tensor that defers the operation until later.
"""
result = cls()
result._tensor = None # Will be computed when needed
result._levels = levels
result._has_device = has_device
result._batchtensor = None
result._delayed_orig = orig
result._delayed_args = args
# Create delayed evaluation function that unwraps Tensor objects
def evaluate_delayed() -> torch.Tensor:
unwrapped_args = []
for arg in args:
if hasattr(arg, "_get_tensor"):
unwrapped_args.append(arg._get_tensor())
else:
unwrapped_args.append(arg)
return orig(*unwrapped_args)
result._delayed = evaluate_delayed
return result
def _get_tensor(self) -> Optional[torch.Tensor]:
"""Get the underlying tensor, handling delayed operations if needed."""
if (
hasattr(self, "_delayed")
and self._delayed is not None
and self._tensor is None
):
# Execute the delayed operation
self._tensor = self._delayed()
# Clear delayed operation to avoid re-execution
self._delayed = None
self._delayed_orig = None
self._delayed_args = None
return self._tensor
def _get_levels(self) -> list[Any]:
"""Get the dimension levels."""
return self._levels
def _get_has_device(self) -> bool:
"""Get whether this tensor has device information."""
return self._has_device
def _get_batchtensor(self) -> Optional[torch.Tensor]:
"""Get the batched tensor representation, creating it lazily if needed."""
if self._batchtensor is None:
self._batchtensor = self._add_batch_dims(
self._get_tensor(), self._get_levels()
)
return self._batchtensor
def _add_batch_dims(
self, t: Optional[torch.Tensor], levels_: list[Any]
) -> Optional[torch.Tensor]:
levels = list(levels_)
while True:
min_real_index = -1
min_index = -1
min_value = float("inf") # INT_MAX equivalent
i = 0
r = 0
for r, l in enumerate(levels):
if not l.is_none():
if not l.is_positional() and l.dim()._level < min_value:
min_value = l.dim()._level
min_index = i
min_real_index = r
i += 1
if min_index == -1:
return t
assert t is not None
t = torch._C._functorch._add_batch_dim(t, min_index, int(min_value))
levels[min_real_index] = DimEntry()
return None
def order(self, *dims: Any) -> _Tensor:
"""Reorder the dimensions of this tensor."""
from ._order import order
result = order(self, *dims)
return result # type: ignore[return-value] # Tensor and torch.Tensor are interchangeable
def stack(tensors: Any, new_dim: Any, dim: int = 0) -> _Tensor:
"""
Stack tensors along a new dimension.
Args:
tensors: Sequence of tensors to stack
new_dim: The new Dim to create for stacking
dim: The dimension position to insert the new dimension (default: 0)
Returns:
Stacked tensor with the new dimension
"""
if not tensors:
raise ValueError("stack expects a non-empty sequence of tensors")
# Check if new_dim is a Dim object
if not isinstance(new_dim, Dim):
# Fall back to regular torch.stack
result = torch.stack(tensors, dim=dim)
return result # type: ignore[return-value]
# Collect all result_levels from input tensors
result_levels = []
infos = []
for t in tensors:
info = TensorInfo.create(t, ensure_batched=False, ensure_present=False)
infos.append(info)
for level in info.levels:
if level not in result_levels:
result_levels.append(level)
# Set the new_dim size to match number of tensors
new_dim.size = len(tensors)
# Match all tensors to the common level structure using _match_levels
inputs = []
for info in infos:
assert info.tensor is not None, "Cannot stack tensors with None tensor data"
matched_tensor = _match_levels(info.tensor, info.levels, result_levels)
inputs.append(matched_tensor)
# Calculate ndim and resolve the dim parameter
ndim = ndim_of_levels(result_levels)
rawdim = 0
if dim is not None and not (isinstance(dim, int) and dim == 0):
from ._wrap import _wrap_dim
d = _wrap_dim(dim, ndim, False)
try:
idx = result_levels.index(d)
except ValueError:
raise TypeError(f"Dimension {dim} does not exist in inputs") from None
rawdim = idx
# Stack tensors at the resolved dimension
result = torch.stack(inputs, rawdim)
# Insert new dimension entry at the correct position
result_levels.insert(rawdim, DimEntry(new_dim))
# Return as a first-class tensor
tensor_result = Tensor.from_positional(
result, result_levels, infos[0].has_device if infos else True
)
return tensor_result # type: ignore[return-value]
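# Illustrative use of stack(): the new Dim is bound to the number of inputs and its
# level is inserted at the resolved position of the result.
#     >>> s = dims(1)
#     >>> parts = [torch.randn(3) for _ in range(4)]
#     >>> out = stack(parts, s)    # s.size becomes 4; out carries dim s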
def split(tensor: Any, split_size_or_sections: Any, dim: Any = None) -> tuple:
"""
Split tensor along a dimension.
Can handle both regular integer sizes and Dim objects for split sizes.
When Dim objects are used, they get bound to the resulting tensor dimensions.
"""
from ._wrap import _wrap_dim
# Check if dim is a Dim object
dim_is_object = isinstance(dim, Dim)
# Parse split_size_or_sections
if isinstance(split_size_or_sections, int):
# Single integer - use regular split
if dim_is_object:
raise TypeError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _Tensor._torch_function_fallback(
torch.Tensor.split,
(type(tensor),),
(tensor, split_size_or_sections),
{"dim": dim},
)
# Check if it's a sequence
sizes = []
all_dims = True
all_ints = True
for item in split_size_or_sections:
sizes.append(item)
if isinstance(item, Dim):
all_ints = False
else:
all_dims = False
if all_ints:
# All integers - use regular split
if dim_is_object:
raise TypeError(
"when dim is specified as a Dim object, split sizes must also be dimensions."
)
return _Tensor._torch_function_fallback(
torch.Tensor.split,
(type(tensor),),
(tensor, split_size_or_sections),
{"dim": dim},
)
if not all_dims:
raise TypeError("split list must be ints or dims but got a mix")
# All are Dim objects - handle first-class dimension split
self_info = TensorInfo.create(tensor, ensure_batched=False, ensure_present=False)
ndim = self_info.ndim()
if not dim_is_object and ndim == 0:
raise TypeError("split expects at least a 1-dimension tensor")
# Wrap the dimension
dim_l = _wrap_dim(dim, ndim, False) if dim is not None else DimEntry(-ndim)
# Find the index of the dimension in levels
idx = None
for i, level in enumerate(self_info.levels):
if level == dim_l:
idx = i
break
if idx is None:
if dim is None:
dim = 0
raise TypeError(f"tensor does not contain dimension {dim}")
# Calculate split indices
indices = []
total_size = 0
unbound = []
for i, size_dim in enumerate(sizes):
if size_dim.is_bound:
indices.append(size_dim.size)
total_size += indices[-1]
else:
indices.append(0)
unbound.append(i)
assert self_info.tensor is not None, "Cannot get tensor size on None tensor"
tensor_size = self_info.tensor.size(idx)
# Handle unbound dimensions
if unbound:
if total_size > tensor_size:
raise TypeError(
f"sizes of target dimensions add up to more ({total_size}) than source dim ({tensor_size})"
)
remaining_size = tensor_size - total_size
chunk_size = (remaining_size + len(unbound) - 1) // len(unbound)
for u in unbound:
sz = min(chunk_size, remaining_size)
sizes[u].size = sz
indices[u] = sz
remaining_size -= sz
elif tensor_size != total_size:
raise TypeError(
f"sum of sizes of target dimensions ({total_size}) do not match the source dim ({tensor_size})"
)
# Perform the split
result_tensors = self_info.tensor.split_with_sizes(indices, idx)
# Create result with new levels
result = []
new_levels = list(self_info.levels)
for i, (result_tensor, size_dim) in enumerate(zip(result_tensors, sizes)):
new_levels[idx] = DimEntry(size_dim)
result.append(
Tensor.from_positional(
result_tensor, list(new_levels), self_info.has_device
)
)
return tuple(result)
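# Illustrative use of split() with Dim objects (a sketch assuming a plain 1-D input):
# unbound size dims are bound to the computed chunk sizes and each piece carries its
# own dim in place of the split dimension.
#     >>> i, j = dims(2)
#     >>> t = torch.randn(10)
#     >>> left, right = split(t, [i, j], dim=0)   # binds i.size == j.size == 5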
def cat(tensors: Any, dim: Any, new_dim: Any) -> _Tensor:
n = dims(1) # Get single Dim instead of tuple
return stack(tensors, n, dim).index([n, dim], new_dim) # type: ignore[list-item]
class DotPart:
"""
Helper class for organizing dimensions in dot products.
"""
def __init__(self) -> None:
self.dims: list[DimEntry] = []
self.total_size = 1
def append(self, dim_entry: Any) -> None:
"""Add a dimension entry to this part."""
self.dims.append(dim_entry)
if not dim_entry.is_positional():
self.total_size *= dim_entry.dim().size
def dot_prepare(parts: list[DotPart], tensor_info: TensorInfo) -> torch.Tensor:
"""
Prepare tensor for dot product by matching levels and reshaping.
"""
new_levels = []
needs_reshape = False
for part in parts:
if len(part.dims) != 1:
needs_reshape = True
new_levels.extend(part.dims)
if tensor_info.tensor is None:
raise RuntimeError("Cannot perform dot product on None tensor")
result = _match_levels(tensor_info.tensor, tensor_info.levels, new_levels)
if not needs_reshape:
return result
# Reshape for matrix operations
view = [part.total_size for part in parts]
return result.reshape(view)
def dot_finish(parts: list[DotPart], result_tensor: torch.Tensor) -> Tensor:
"""
Finish dot product by reshaping result and creating Tensor.
"""
result_levels = []
needs_reshape = False
for part in parts:
if len(part.dims) != 1:
needs_reshape = True
result_levels.extend(part.dims)
if needs_reshape:
new_size = []
for level in result_levels:
new_size.append(level.dim().size)
result_tensor = result_tensor.reshape(new_size)
tensor_result = Tensor.from_positional(result_tensor, result_levels, True)
return tensor_result # type: ignore[return-value]
def dot(lhs: Any, rhs: Any, sum_dims: Any) -> Union[_Tensor, torch.Tensor]:
"""
Perform dot product between two tensors along specified dimensions.
Args:
lhs: Left-hand side tensor
rhs: Right-hand side tensor
sum_dims: Dimensions to sum over (contract)
Returns:
Result of dot product
"""
# Get tensor info
lhs_info = TensorInfo.create(lhs, ensure_batched=False, ensure_present=False)
rhs_info = TensorInfo.create(rhs, ensure_batched=False, ensure_present=False)
if not (lhs_info and rhs_info):
# Fall back to regular operations
return torch.matmul(lhs, rhs)
assert lhs_info.tensor is not None and rhs_info.tensor is not None, (
"Cannot perform dot product on None tensors"
)
lhs_strides = lhs_info.tensor.stride()
rhs_strides = rhs_info.tensor.stride()
# Create dot parts for different dimension categories
lro_dims = DotPart() # Left-right-output (batch dims)
lo_dims = DotPart() # Left-output only
ro_dims = DotPart() # Right-output only
lr_dims = DotPart() # Left-right (contracted dims)
def insert_dim(d: Any, lhs_idx: Any, rhs_idx: Any) -> None:
"""Insert dimension into appropriate part based on stride pattern."""
reduced = d in sum_dims
lhs_stride = lhs_strides[lhs_idx] if lhs_idx is not None else 0
rhs_stride = rhs_strides[rhs_idx] if rhs_idx is not None else 0
if reduced:
lr_dims.append(d)
else:
if (lhs_stride == 0) == (rhs_stride == 0):
lro_dims.append(d) # Both have or both lack this dim
elif lhs_stride != 0:
lo_dims.append(d) # Only lhs has this dim
else:
ro_dims.append(d) # Only rhs has this dim
# Track which rhs dimensions we've seen
rhs_seen = [False] * len(rhs_info.levels)
# Process lhs dimensions
for i, lhs_level in enumerate(lhs_info.levels):
rhs_idx = None
for j, rhs_level in enumerate(rhs_info.levels):
if lhs_level == rhs_level:
rhs_idx = j
rhs_seen[j] = True
break
insert_dim(lhs_level, i, rhs_idx)
# Process remaining rhs dimensions
for i, rhs_level in enumerate(rhs_info.levels):
if not rhs_seen[i]:
insert_dim(rhs_level, None, i)
# Validate sum dimensions exist
if len(lr_dims.dims) != len(sum_dims):
for d in sum_dims:
if d not in lhs_info.levels and d not in rhs_info.levels:
raise ValueError(f"summing over non-existent dimension {d}")
# Prepare tensors and perform matrix multiplication
if len(lro_dims.dims) != 0:
# Batched matrix multiply
lhs_tensor = dot_prepare([lro_dims, lo_dims, lr_dims], lhs_info)
rhs_tensor = dot_prepare([lro_dims, lr_dims, ro_dims], rhs_info)
result = torch.bmm(lhs_tensor, rhs_tensor)
return dot_finish([lro_dims, lo_dims, ro_dims], result)
else:
# Regular matrix multiply
lhs_tensor = dot_prepare([lo_dims, lr_dims], lhs_info)
rhs_tensor = dot_prepare([lr_dims, ro_dims], rhs_info)
result = torch.mm(lhs_tensor, rhs_tensor)
return dot_finish([lo_dims, ro_dims], result)
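# Illustrative sketch of dot(): contracting over a shared named dim behaves like a
# (possibly batched) matmul. As written above, sum_dims is expected to hold DimEntry
# objects, and this calls the module-level helper directly rather than going through
# a matmul override.
#     >>> i, j, k = dims(3)
#     >>> A = torch.randn(3, 4)[i, j]     # dims (i, j)
#     >>> B = torch.randn(4, 5)[j, k]     # dims (j, k)
#     >>> C = dot(A, B, [DimEntry(j)])    # contract j -> dims (i, k)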
from functorch.dim._wrap import _wrap
from functorch.dim.wrap_type import wrap_type
wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
def index(self: Any, positions: Any, dims: Any) -> _Tensor:
"""
Index a regular tensor by binding specified positions to dims.
This converts a regular tensor to a first-class tensor by binding
the specified positional dimensions to Dim objects.
Args:
positions: Tuple of dimension positions to bind
dims: Dim objects or tuple of Dim objects to bind to
Returns:
First-class tensor with specified dimensions bound
"""
# If this is already a first-class tensor (_Tensor), call its index method directly
if isinstance(self, _Tensor):
return _Tensor.index(self, positions, dims)
# Convert regular tensor to first-class tensor
info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
# Create the first-class tensor
assert info.tensor is not None, "Cannot index None tensor"
result = Tensor.from_positional(info.tensor, info.levels, info.has_device)
# Now call the index method on the first-class tensor
# Cast result to _Tensor for the method call
return _Tensor.index(result, positions, dims) # type: ignore[arg-type]
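# Illustrative example of the free index() helper: bind a positional dimension of a
# plain tensor to a Dim object.
#     >>> i = dims(1)
#     >>> t = torch.randn(3, 4)
#     >>> ft = index(t, 0, i)    # positional dim 0 becomes named dim i (i.size == 3)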
def _def(name: str, *args: Any, **kwargs: Any) -> None:
orig = getattr(torch.Tensor, name)
setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
_def("mean")
_def("sum")
_def("all")
_def("amax")
_def("amin")
_def("aminmax")
_def("any")
_def("count_nonzero")
_def("logsumexp")
_def("nanmean")
_def("nansum")
_def("prod")
_def("std", keepdim_offset=2)
_def("var", keepdim_offset=2)
_def("max", single_dim=True)
_def("min", single_dim=True)
_def("argmax", single_dim=True)
_def("argmin", single_dim=True)
_def("kthvalue", single_dim=True)
_def("median", single_dim=True)
_def("nanmedian", single_dim=True)
_def("mode", single_dim=True)
_def("sort", reduce=False)
_def("argsort", reduce=False)
_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)
_def("cummax", single_dim=True, reduce=False)
_def("cummin", single_dim=True, reduce=False)
_def("cumprod", single_dim=True, reduce=False)
_def("cumprod_", single_dim=True, reduce=False)
_def("cumsum", single_dim=True, reduce=False)
_def("cumsum_", single_dim=True, reduce=False)
_def("logcumsumexp", single_dim=True, reduce=False)
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
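# Illustrative sketch: the _def() wrappers above allow reductions to take Dim objects
# in place of integer dim arguments, enabling einsum-like code such as a matmul
# written with first-class dims:
#     >>> def mm(A, B):
#     ...     i, k, j = dims(3)
#     ...     return (A[i, k] * B[k, j]).sum(k).order(i, j)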
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce