from __future__ import annotations
|
|
|
|
import dis
|
|
import inspect
|
|
import sys
|
|
from typing import Any, Optional, TYPE_CHECKING, Union
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Callable, Sequence
|
|
|
|
import torch
|
|
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
|
|
|
|
from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
|
|
from ._enable_all_layers import EnableAllLayers
|
|
from ._py_inst_decoder import _PyInstDecoder
|
|
from ._tensor_info import TensorInfo
|
|
|
|
|
|
POINTWISE_OPTIMIZE = True
|
|
DOT_OPTIMIZED = True
|
|
|
|
# Global dimension level counter
|
|
_n_dims_created = 0
|
|
|
|
|
|
def _relevant_op(opcode: Optional[str]) -> bool:
|
|
"""Check if opcode is relevant for variable assignment."""
|
|
return bool(opcode and opcode.startswith("STORE_"))
|
|
|
|
|
|
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
|
"""Handle tensor conversion for torch function integration."""
|
|
return tensor
|
|
|
|
|
|
def _create_dim(name: str, size: Optional[int] = None) -> Dim:
|
|
"""Create a new Dim object."""
|
|
return Dim(name, size if size is not None else -1)
|
|
|
|
|
|
def dims(
|
|
n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
|
|
) -> Union[Dim, tuple[Dim, ...]]:
|
|
"""
|
|
Create and return one or more Dim objects.
|
|
|
|
Uses bytecode inspection to determine variable names when possible.
|
|
|
|
Args:
|
|
n (int, optional): The number of dimensions to create. Can be omitted if sizes is specified.
|
|
sizes (List[Optional[int]], optional): A list with one entry per dimension to be
created, giving that dimension's size, or None to leave the size unset.
|
|
|
|
Returns:
|
|
Union[Dim, Tuple[Dim, ...]]: Single Dim if n=1, tuple of Dims otherwise.
|
|
|
|
Examples:
|
|
>>> batch, channel, width, height = dims(4)
|
|
>>> batch, channel, width, height = dims(sizes=[None, 3, 224, 224])
|
|
>>> single_dim = dims(1)
|
|
"""
|
|
specified_ndims = -1
|
|
found_ndims = 0
|
|
|
|
# Parse arguments
|
|
if sizes is not None:
|
|
specified_ndims = len(sizes)
|
|
if n is not None:
|
|
specified_ndims = n
|
|
|
|
# Use bytecode inspection
|
|
frame = inspect.currentframe()
|
|
if frame is None:
|
|
raise RuntimeError("Unable to get current frame")
|
|
frame = frame.f_back
|
|
try:
|
|
if frame is None:
|
|
raise RuntimeError("Unable to get caller frame")
|
|
code = frame.f_code
|
|
lasti = frame.f_lasti
|
|
|
|
decoder = _PyInstDecoder(code, lasti)
|
|
|
|
if sys.version_info >= (3, 11):
|
|
if decoder.opcode() == "PRECALL":
|
|
decoder.next()
|
|
|
|
# Move to next instruction after the call
|
|
decoder.next()
|
|
|
|
# Determine number of dimensions from bytecode
|
|
if _relevant_op(decoder.opcode()):
|
|
found_ndims = 1
|
|
elif decoder.opcode() == "UNPACK_SEQUENCE":
|
|
found_ndims = decoder.oparg()
|
|
decoder.next() # Move past UNPACK_SEQUENCE
|
|
|
|
if specified_ndims == -1:
|
|
if found_ndims == 0:
|
|
raise SyntaxError(
|
|
"dims() must be assigned to a sequence of variable names or have argument n specified"
|
|
)
|
|
specified_ndims = found_ndims
|
|
|
|
if found_ndims != specified_ndims:
|
|
found_ndims = 0
|
|
|
|
def genobject(i: int) -> Dim:
|
|
nonlocal found_ndims
|
|
name = None
|
|
if i < found_ndims:
|
|
name = decoder.name()
|
|
|
|
if not name:
|
|
name = f"d{i}"
|
|
found_ndims = 0
|
|
else:
|
|
decoder.next() # Move to next STORE instruction
|
|
|
|
size = sizes[i] if sizes is not None else None
|
|
return _create_dim(name, size)
|
|
|
|
# Validate sizes parameter
|
|
if sizes is not None and len(sizes) != specified_ndims:
|
|
raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
|
|
|
|
if specified_ndims == 1:
|
|
return genobject(0)
|
|
|
|
result = []
|
|
for i in range(specified_ndims):
|
|
result.append(genobject(i))
|
|
|
|
return tuple(result)
|
|
|
|
finally:
|
|
del frame
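# Illustrative usage sketch (not part of the original module; assumes the
# public `functorch.dim` entry points described in the docstring above):
#
#   import torch
#   from functorch.dim import dims
#
#   batch, channel = dims(2)      # names recovered from the STORE bytecode
#   x = torch.randn(4, 3)
#   xb = x[batch, channel]        # indexing binds batch -> 4, channel -> 3
#   assert (batch.size, channel.size) == (4, 3)
#   assert repr(batch) == "batch"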
|
|
|
|
|
|
class DimList:
|
|
"""
|
|
A list of first-class dimensions that can be bound to tensor dimensions.
|
|
|
|
A DimList can be in one of two states:
|
|
1. Unbound: Created with just a name, no specific dimensions yet
|
|
2. Bound: Either created with specific dimensions/sizes, or bound later via bind() or bind_len()
|
|
"""
|
|
|
|
_name: Optional[str]
|
|
_dims: list[Dim]
|
|
_bound: bool
|
|
|
|
def __init__(
|
|
self,
|
|
len_or_dims: Optional[Union[int, Sequence]] = None,
|
|
name: Optional[str] = None,
|
|
):
|
|
"""
|
|
Initialize a new DimList object.
|
|
|
|
Args:
|
|
len_or_dims: Optional length (int) or sequence of dimensions/sizes
|
|
name: Optional name for the dimension list
|
|
"""
|
|
# Initialize attributes
|
|
self._name = name
|
|
self._dims: list = []
|
|
self._bound = False
|
|
|
|
if isinstance(len_or_dims, int):
|
|
self.bind_len(len_or_dims)
|
|
elif len_or_dims is not None:
|
|
dims = []
|
|
for i, item in enumerate(len_or_dims):
|
|
if isinstance(item, int):
|
|
dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
|
|
dims.append(Dim(dim_name, item))
|
|
else:
|
|
dims.append(Dim(item))
|
|
self._set_dims(dims)
|
|
|
|
def _set_dims(self, dims: list) -> None:
|
|
"""Set the dimensions and mark as bound."""
|
|
self._bound = True
|
|
self._dims = dims
|
|
|
|
def bind_len(self, size: int) -> None:
|
|
"""
|
|
Bind this DimList to a specific length.
|
|
|
|
Args:
|
|
size: Number of dimensions to bind to
|
|
|
|
Raises:
|
|
DimensionBindError: If already bound to a different size
|
|
"""
|
|
if self._bound:
|
|
if len(self._dims) != size:
|
|
raise DimensionBindError(
|
|
f"Dimlist has size {len(self._dims)} but it is being bound to size {size}"
|
|
)
|
|
else:
|
|
self._bound = True
|
|
self._dims = []
|
|
for i in range(size):
|
|
dim_name = f"{self._name}{i}" if self._name else f"dim{i}"
|
|
self._dims.append(Dim(dim_name))
|
|
|
|
def bind(self, sizes: Sequence[int]) -> None:
|
|
"""
|
|
Bind this DimList to specific sizes.
|
|
|
|
Args:
|
|
sizes: Sequence of sizes for each dimension
|
|
|
|
Raises:
|
|
ValueError: If sizes is not a sequence
|
|
"""
|
|
if not hasattr(sizes, "__len__") or not hasattr(sizes, "__getitem__"):
|
|
raise ValueError("expected a sequence")
|
|
|
|
size = len(sizes)
|
|
self.bind_len(size)
|
|
|
|
for i, dim_size in enumerate(sizes):
|
|
self._dims[i].size = int(dim_size)
|
|
|
|
def _size(self) -> int:
|
|
if not self._bound:
|
|
raise DimensionBindError("DimList not bound")
|
|
return len(self._dims)
|
|
|
|
def size(self) -> int:
|
|
"""Return the size (number of dimensions) of this DimList."""
|
|
return self._size()
|
|
|
|
def _set_bound(self, b: bool) -> None:
|
|
"""Set the bound status (for internal use)."""
|
|
self._bound = b
|
|
|
|
@property
|
|
def is_bound(self) -> bool:
|
|
"""Property to check if DimList is bound."""
|
|
return self._bound
|
|
|
|
def __len__(self) -> int:
|
|
"""Return the length of the DimList."""
|
|
return self.size()
|
|
|
|
def __getitem__(self, key: Union[int, slice]) -> Union[Dim, tuple[Dim, ...]]:
|
|
if not self._bound:
|
|
raise DimensionBindError("DimList not bound")
|
|
|
|
if isinstance(key, int):
|
|
if key < 0 or key >= len(self._dims):
|
|
raise IndexError("index out of bounds")
|
|
return self._dims[key]
|
|
elif isinstance(key, slice):
|
|
start, stop, step = key.indices(len(self._dims))
|
|
result = []
|
|
for i in range(start, stop, step):
|
|
result.append(self._dims[i])
|
|
return tuple(result)
|
|
else:
|
|
raise ValueError("expected an int or a slice")
|
|
|
|
def __repr__(self) -> str:
|
|
"""Return string representation of the DimList."""
|
|
if self._bound:
|
|
# Show as tuple representation
|
|
return f"({', '.join(repr(dim) for dim in self._dims)})"
|
|
elif self._name is not None:
|
|
# Show as *name for unbound with name
|
|
return f"*{self._name}"
|
|
else:
|
|
# Show as <unbound_dimlist> for unbound without name
|
|
return "<unbound_dimlist>"
|
|
|
|
def __str__(self) -> str:
|
|
"""Return string representation of the DimList."""
|
|
return self.__repr__()
|
|
|
|
@classmethod
|
|
def __torch_function__(
|
|
cls,
|
|
func: Callable,
|
|
types: tuple,
|
|
args: tuple = (),
|
|
kwargs: Optional[dict] = None,
|
|
) -> Any:
|
|
return _Tensor.__torch_function__(func, types, args, kwargs)
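# Minimal behavior sketch for DimList (based on the class above; not part of
# the original module):
#
#   lst = DimList(name="batch")         # unbound; repr() is "*batch"
#   lst.bind([2, 3])                    # creates Dims batch0 (size 2) and batch1 (size 3)
#   assert len(lst) == 2 and lst[0].size == 2
#
#   bound = DimList([2, 3], name="b")   # bound immediately from the given sizes
#   assert bound.is_bound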
|
|
|
|
|
|
def _create_dimlist(
|
|
name: str, size: Optional[Union[int, list[Optional[int]]]] = None
|
|
) -> DimList:
|
|
"""Create a DimList object with the given name and optional size."""
|
|
dimlist = DimList(name=name)
|
|
if size is not None:
|
|
if isinstance(size, int):
|
|
dimlist.bind_len(size)
|
|
else:
|
|
# size is a list of optional ints
|
|
dimlist.bind_len(len(size))
|
|
for i, s in enumerate(size):
|
|
if s is not None:
|
|
dimlist._dims[i].size = s
|
|
return dimlist
|
|
|
|
|
|
def dimlists(
|
|
n: Optional[int] = None, sizes: Optional[list[Optional[int]]] = None
|
|
) -> Union[DimList, tuple[DimList, ...]]:
|
|
"""
|
|
Create and return one or more DimList objects.
|
|
|
|
Similar to dims() but creates DimList objects instead.
|
|
"""
|
|
specified_ndims = -1
|
|
found_ndims = 0
|
|
|
|
# Parse arguments
|
|
if sizes is not None:
|
|
specified_ndims = len(sizes)
|
|
if n is not None:
|
|
specified_ndims = n
|
|
|
|
frame = inspect.currentframe()
|
|
if frame is None:
|
|
raise RuntimeError("Unable to get current frame")
|
|
frame = frame.f_back
|
|
try:
|
|
if frame is None:
|
|
raise RuntimeError("Unable to get caller frame")
|
|
code = frame.f_code
|
|
lasti = frame.f_lasti
|
|
|
|
decoder = _PyInstDecoder(code, lasti)
|
|
|
|
if sys.version_info >= (3, 11):
|
|
if decoder.opcode() == "PRECALL":
|
|
decoder.next()
|
|
|
|
# Move to next instruction after the call
|
|
decoder.next()
|
|
|
|
# Determine number of dimensions from bytecode
|
|
if _relevant_op(decoder.opcode()):
|
|
found_ndims = 1
|
|
elif decoder.opcode() == "UNPACK_SEQUENCE":
|
|
found_ndims = decoder.oparg()
|
|
decoder.next() # Move past UNPACK_SEQUENCE
|
|
|
|
if specified_ndims == -1:
|
|
if found_ndims == 0:
|
|
raise SyntaxError(
|
|
"dimlists() must be assigned to a sequence of variable names or have argument n specified"
|
|
)
|
|
specified_ndims = found_ndims
|
|
|
|
if found_ndims != specified_ndims:
|
|
found_ndims = 0
|
|
|
|
# Generator function for dimlist names
|
|
def genobject(i: int) -> str:
|
|
nonlocal found_ndims
|
|
name = None
|
|
if i < found_ndims:
|
|
name = decoder.name()
|
|
|
|
if not name:
|
|
name = f"d{i}"
|
|
found_ndims = 0
|
|
else:
|
|
decoder.next() # Move to next STORE instruction
|
|
|
|
return name
|
|
|
|
# Validate sizes
|
|
if sizes is not None and len(sizes) != specified_ndims:
|
|
raise ValueError(f"expected {specified_ndims} sizes but found {len(sizes)}")
|
|
|
|
# Create dimlists
|
|
if specified_ndims == 1:
|
|
name = genobject(0)
|
|
return _create_dimlist(name, sizes[0] if sizes is not None else None)
|
|
|
|
result = []
|
|
for i in range(specified_ndims):
|
|
name = genobject(i)
|
|
size = sizes[i] if sizes is not None else None
|
|
result.append(_create_dimlist(name, size))
|
|
|
|
return tuple(result)
|
|
|
|
finally:
|
|
del frame
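# Illustrative sketch (not part of the original module): like dims(), the
# variable name is recovered from the assignment when possible.
#
#   ps = dimlists(1)     # a single unbound DimList named "ps"
#   print(ps)            # -> *ps
#   ps.bind_len(2)       # now holds two fresh Dims: ps0 and ps1
#   print(ps)            # -> (ps0, ps1)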
|
|
|
|
|
|
class DimensionMismatchError(Exception):
|
|
pass
|
|
|
|
|
|
class DimensionBindError(Exception):
|
|
pass
|
|
|
|
|
|
from . import op_properties
|
|
|
|
|
|
def _safe_print(*args: Any, **kwargs: Any) -> None:
|
|
"""Safe print that avoids recursive torch function dispatches."""
|
|
import sys
|
|
|
|
# Convert any torch objects to basic representations
|
|
safe_args = []
|
|
for arg in args:
|
|
if hasattr(arg, "__class__") and "torch" in str(type(arg)):
|
|
safe_args.append(f"<{type(arg).__name__}>")
|
|
else:
|
|
safe_args.append(str(arg))
|
|
|
|
print(*safe_args, **kwargs, file=sys.stderr)
|
|
|
|
|
|
class _Tensor:
|
|
def _get_levels(self) -> list[Any]:
|
|
raise NotImplementedError("_get_levels must be implemented by subclass")
|
|
|
|
def _get_tensor(self) -> Optional[torch.Tensor]:
|
|
raise NotImplementedError("_get_tensor must be implemented by subclass")
|
|
|
|
@property
|
|
def ndim(self) -> int:
|
|
raise NotImplementedError("ndim must be implemented by subclass")
|
|
|
|
@property
|
|
def dims(self) -> tuple[Any, ...]:
|
|
return tuple(l.dim() for l in self._get_levels() if not l.is_positional())
|
|
|
|
def dim(self) -> int:
|
|
return self.ndim
|
|
|
|
@classmethod
|
|
def __torch_function__(
|
|
cls,
|
|
func: Callable,
|
|
types: tuple,
|
|
args: tuple = (),
|
|
kwargs: Optional[dict] = None,
|
|
) -> Any:
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
if DOT_OPTIMIZED and func is torch.Tensor.__mul__:
|
|
# Check conditions: 2 args, both are tensor-like, both 0-dimensional
|
|
if (
|
|
len(args) == 2
|
|
and not kwargs
|
|
and isinstance(args[0], (_Tensor, torch.Tensor))
|
|
and isinstance(args[1], (_Tensor, torch.Tensor))
|
|
):
|
|
# Get tensor info for both operands
|
|
lhs_info = TensorInfo.create(
|
|
args[0], ensure_batched=False, ensure_present=False
|
|
)
|
|
rhs_info = TensorInfo.create(
|
|
args[1], ensure_batched=False, ensure_present=False
|
|
)
|
|
|
|
if (
|
|
lhs_info
|
|
and rhs_info
|
|
and lhs_info.tensor is not None
|
|
and rhs_info.tensor is not None
|
|
and lhs_info.tensor.dim() == 0
|
|
and rhs_info.tensor.dim() == 0
|
|
):
|
|
if (
|
|
lhs_info.tensor.is_floating_point()
|
|
and rhs_info.tensor.is_floating_point()
|
|
):
|
|
# Collect all unique levels and has_device
|
|
has_device = lhs_info.has_device or rhs_info.has_device
|
|
levels = []
|
|
|
|
for level in lhs_info.levels:
|
|
if level not in levels:
|
|
levels.append(level)
|
|
for level in rhs_info.levels:
|
|
if level not in levels:
|
|
levels.append(level)
|
|
|
|
# Debug print
|
|
# print(f"DEBUG: Creating delayed mul, levels: {levels}, has_device: {has_device}")
|
|
|
|
# Create delayed tensor
|
|
return Tensor.create_delayed(func, args, levels, has_device)
|
|
|
|
if func is torch.Tensor.__getitem__:
|
|
from functorch.dim._getsetitem import getitem
|
|
|
|
return getitem(cls, func, types, args, kwargs)
|
|
|
|
if func is torch.Tensor.__setitem__:
|
|
from functorch.dim._getsetitem import setitem
|
|
|
|
# args should be (tensor, index, value)
|
|
if len(args) == 3:
|
|
setitem(args[0], args[1], args[2])
|
|
return None
|
|
else:
|
|
raise ValueError(f"Expected 3 args for __setitem__, got {len(args)}")
|
|
|
|
# Fast-path for len; mostly to avoid infinite loop in TestMinFunctorchOnly.test_softmax_split
|
|
if func is torch.Tensor.__len__:
|
|
return args[0].size(0)
|
|
|
|
# Special handling for torch.softmax - use the pre-wrapped version
|
|
if func is torch.softmax:
|
|
return softmax(*args, **kwargs)
|
|
|
|
# Special handling for torch.stack - use the custom stack function
|
|
if func is torch.stack:
|
|
return stack(*args, **kwargs)
|
|
|
|
if (
|
|
func is torch.Tensor.split
|
|
or func is torch._VF.split # type: ignore[attr-defined]
|
|
or func is torch._VF.split_with_sizes # type: ignore[attr-defined]
|
|
or func is torch.split
|
|
):
|
|
return split(*args, **kwargs)
|
|
|
|
return _Tensor._torch_function_fallback(func, types, args, kwargs)
|
|
|
|
@staticmethod
|
|
def _torch_function_fallback(
|
|
func: Callable, types: tuple, args: tuple, kwargs: dict
|
|
) -> Any:
|
|
"""Fallback torch function implementation for non-special-cased functions."""
|
|
is_pointwise = POINTWISE_OPTIMIZE and func in op_properties.pointwise
|
|
# TODO: optimize pytree here
|
|
flat_args, spec = tree_flatten((args, kwargs))
|
|
device_holding_tensor = None
|
|
|
|
infos: list[TensorInfo] = []
|
|
result_levels: list[DimEntry] = []
|
|
|
|
for f in flat_args:
|
|
info = TensorInfo.create(f, not is_pointwise, False)
|
|
infos.append(info)
|
|
if info:
|
|
assert is_pointwise or info.batchedtensor is not None
|
|
if device_holding_tensor is None and info.has_device:
|
|
device_holding_tensor = info.tensor
|
|
# Collect all unique levels
|
|
for level in info.levels:
|
|
assert isinstance(level, DimEntry)
|
|
if level not in result_levels:
|
|
result_levels.append(level)
|
|
|
|
if is_pointwise:
|
|
# Pointwise operation: match all tensors to common levels
|
|
for i, info in enumerate(infos):
|
|
if info and info.tensor is not None:
|
|
tensor = info.tensor
|
|
if device_holding_tensor is not None and not info.has_device:
|
|
tensor = tensor.to(device_holding_tensor.device)
|
|
ml = _match_levels(tensor, info.levels, result_levels)
|
|
flat_args[i] = handle_from_tensor(ml)
|
|
|
|
unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
|
|
result = func(*unflat_args, **unflat_kwargs)
|
|
|
|
# Wrap tensor results
|
|
def wrap_tensor(obj: Any) -> Any:
|
|
if isinstance(obj, torch.Tensor):
|
|
return Tensor.from_positional(
|
|
obj, result_levels, device_holding_tensor is not None
|
|
)
|
|
return obj
|
|
|
|
# Small fastpath
|
|
if isinstance(result, torch.Tensor):
|
|
return wrap_tensor(result)
|
|
else:
|
|
return tree_map(wrap_tensor, result)
|
|
|
|
# Non-pointwise operation: use functorch vmap layers
|
|
with EnableAllLayers(result_levels) as guard:
|
|
# Update arguments with batched tensors
|
|
for i, info in enumerate(infos):
|
|
if info and info.batchedtensor is not None:
|
|
batched = info.batchedtensor
|
|
if device_holding_tensor is not None and not info.has_device:
|
|
batched = batched.to(device_holding_tensor.device)
|
|
guard.inplace_update_layers(batched, info.levels)
|
|
flat_args[i] = handle_from_tensor(batched)
|
|
|
|
unflat_args, unflat_kwargs = tree_unflatten(flat_args, spec)
|
|
result = func(*unflat_args, **unflat_kwargs)
|
|
|
|
# Unwrap results from functorch layers
|
|
def unwrap_tensor(obj: Any) -> Any:
|
|
if isinstance(obj, torch.Tensor):
|
|
return guard.from_batched(obj, device_holding_tensor is not None)
|
|
return obj
|
|
|
|
if isinstance(result, torch.Tensor):
|
|
return unwrap_tensor(result)
|
|
else:
|
|
return tree_map(unwrap_tensor, result)
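# Sketch of what the fallback above provides (not part of the original
# module): pointwise ops broadcast by dimension identity rather than position.
#
#   i, j = dims(2)
#   a = torch.randn(3)[i]         # carries dim i
#   b = torch.randn(4)[j]         # carries dim j
#   c = a + b                     # levels are unioned -> result has dims (i, j)
#   assert c.order(i, j).shape == (3, 4)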
|
|
|
|
def __setitem__(self, index: Any, value: Any) -> None:
|
|
"""Set values in tensor using first-class dimensions."""
|
|
from functorch.dim._getsetitem import setitem
|
|
|
|
return setitem(self, index, value)
|
|
|
|
# expand and index are OK to be methods because they don't have torch.*
|
|
# versions, but if they did, they would need the stack/cat treatment
|
|
|
|
def expand(self, *args: Dim) -> _Tensor:
|
|
"""
|
|
Expand tensor by adding new dimensions or expanding existing dimensions.
|
|
|
|
If all arguments are Dim objects, adds new named dimensions.
|
|
Otherwise, falls back to regular tensor expansion behavior.
|
|
|
|
Args:
|
|
args: Either Dim objects for new dimensions or sizes for regular expansion
|
|
|
|
Returns:
|
|
New tensor with expanded dimensions
|
|
|
|
Example:
|
|
>>> i, j, k = dims(sizes=[None, 2, 5])
|
|
>>> t = torch.randn(3, 4)
|
|
>>> expanded = t[i].expand(j, k) # Add j, k dimensions
|
|
>>> expanded2 = t[i].expand(2, 4) # Regular expand with sizes
|
|
"""
|
|
info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
|
|
|
|
for arg in args:
|
|
if not isinstance(arg, Dim):
|
|
# Not all args are Dims, fallback to regular expand
|
|
if isinstance(self, torch.Tensor) and not isinstance(self, _Tensor):
|
|
return torch.Tensor.expand(self, *args)
|
|
else:
|
|
return self.__torch_function__(
|
|
torch.Tensor.expand, (type(self),), (self,) + args
|
|
)
|
|
|
|
# All args are Dim objects - proceed with first-class dimension expansion
|
|
if not info:
|
|
# No tensor info available, fallback
|
|
return self.__torch_function__(
|
|
torch.Tensor.expand, (type(self),), (self,) + args
|
|
)
|
|
|
|
# First-class dimension expansion - all args are Dim objects
|
|
data = info.tensor
|
|
if data is None:
|
|
# No tensor data available, fallback
|
|
return self.__torch_function__(
|
|
torch.Tensor.expand, (type(self),), (self,) + args
|
|
)
|
|
|
|
levels = info.levels
|
|
|
|
new_levels: list[DimEntry] = []
|
|
new_sizes = []
|
|
new_strides = []
|
|
|
|
for d in args:
|
|
# Check if dimension already exists in current levels or new_levels
|
|
for level in levels:
|
|
if not level.is_positional() and level.dim() is d:
|
|
raise DimensionBindError(
|
|
f"expanding dimension {d} already exists in tensor with dims"
|
|
)
|
|
for new_level in new_levels:
|
|
if not new_level.is_positional() and new_level.dim() is d:
|
|
raise DimensionBindError(
|
|
f"expanding dimension {d} already exists in tensor with dims"
|
|
)
|
|
|
|
new_levels.append(DimEntry(d))
|
|
new_sizes.append(d.size)
|
|
new_strides.append(0)
|
|
|
|
# Add existing levels
|
|
new_levels.extend(levels)
|
|
|
|
# Add existing sizes and strides
|
|
orig_sizes = list(data.size())
|
|
orig_strides = list(data.stride())
|
|
new_sizes.extend(orig_sizes)
|
|
new_strides.extend(orig_strides)
|
|
|
|
# Create expanded tensor using as_strided
|
|
expanded_data = data.as_strided(new_sizes, new_strides, data.storage_offset())
|
|
|
|
# Return new tensor with expanded dimensions
|
|
result = Tensor.from_positional(expanded_data, new_levels, info.has_device)
|
|
return result # type: ignore[return-value] # Tensor and torch.Tensor are interchangeable
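# Illustrative sketch of the first-class-dimension path above (not part of the
# original module); the new dims must already be bound to a size:
#
#   i = dims(1)
#   j, k = dims(sizes=[2, 5])
#   t = torch.randn(3)[i]               # dims: (i,)
#   e = t.expand(j, k)                  # dims: (j, k, i); j, k broadcast with stride 0
#   assert e.order(j, k, i).shape == (2, 5, 3)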
|
|
|
|
def index(
|
|
self,
|
|
dims: Union[int, Dim, tuple[Union[int, Dim], ...], list[Union[int, Dim]]],
|
|
indices: Union[
|
|
int,
|
|
slice,
|
|
torch.Tensor,
|
|
tuple[Union[int, slice, torch.Tensor], ...],
|
|
list[Union[int, slice, torch.Tensor]],
|
|
],
|
|
) -> _Tensor:
|
|
"""
|
|
Index tensor using first-class dimensions.
|
|
"""
|
|
from ._dim_entry import _match_levels
|
|
from ._getsetitem import getsetitem_flat, invoke_getitem
|
|
from ._wrap import _wrap_dim
|
|
|
|
# Helper to check if obj is a dimpack (tuple/list) and extract items
|
|
def maybe_dimpack(obj: Any, check_first: bool = False) -> tuple[Any, bool]:
|
|
if isinstance(obj, (tuple, list)):
|
|
return list(obj), True
|
|
return None, False
|
|
|
|
def parse_dim_entry(s: Any) -> Any:
|
|
d = _wrap_dim(s, self.ndim, False)
|
|
if d.is_none():
|
|
raise TypeError(f"expected a dimension specifyer but found {repr(s)}")
|
|
return d
|
|
|
|
# Helper for dimension not present errors
|
|
def dim_not_present(d: Any) -> None:
|
|
if d.is_positional():
|
|
raise TypeError(
|
|
f"dimension {d.position() + self.ndim} not in tensor of {self.ndim} dimensions"
|
|
)
|
|
else:
|
|
raise TypeError(f"dimension {repr(d.dim())} not in tensor")
|
|
|
|
dims_list: list[Union[int, Dim]] = []
|
|
indices_list: list[Union[int, slice, torch.Tensor]] = []
|
|
|
|
lhs_list = isinstance(dims, (tuple, list))
|
|
rhs_list = isinstance(indices, (tuple, list))
|
|
|
|
if lhs_list and rhs_list:
|
|
# Type narrowing: we know dims and indices are sequences here
|
|
dims_seq = dims # type: ignore[assignment]
|
|
indices_seq = indices # type: ignore[assignment]
|
|
if len(dims_seq) != len(indices_seq): # type: ignore[arg-type]
|
|
raise TypeError(
|
|
f"dims ({len(dims_seq)}) and indices ({len(indices_seq)}) must have the same length" # type: ignore[arg-type]
|
|
)
|
|
dims_list.extend(dims_seq) # type: ignore[arg-type]
|
|
indices_list.extend(indices_seq) # type: ignore[arg-type]
|
|
else:
|
|
dims_list.append(dims) # type: ignore[arg-type]
|
|
indices_list.append(indices) # type: ignore[arg-type]
|
|
|
|
# Create tensor info
|
|
self_info = TensorInfo.create(self, False, False)
|
|
|
|
new_levels: list[Any] = []
|
|
to_flatten: list[Any] = []
|
|
dims_list_flat = []
|
|
|
|
# Process each dim specification
|
|
for i in range(len(dims_list)):
|
|
m, is_dimpack = maybe_dimpack(dims_list[i], check_first=False)
|
|
if is_dimpack:
|
|
if len(m) == 0:
|
|
dims_list_flat.append(DimEntry()) # Empty dimpack
|
|
continue
|
|
|
|
first = parse_dim_entry(m[0])
|
|
dims_list_flat.append(first)
|
|
|
|
if len(m) == 1:
|
|
continue
|
|
|
|
# Multi-element dimpack requires flattening
|
|
if len(to_flatten) == 0:
|
|
new_levels.extend(self_info.levels)
|
|
|
|
rest = []
|
|
for j in range(1, len(m)):
|
|
d = parse_dim_entry(m[j])
|
|
removed = False
|
|
for k in range(len(new_levels)):
|
|
if new_levels[k] == d:
|
|
new_levels.pop(k)
|
|
removed = True
|
|
break
|
|
if not removed:
|
|
dim_not_present(d)
|
|
rest.append(d)
|
|
|
|
# Find first in new_levels
|
|
first_idx = None
|
|
for k in range(len(new_levels)):
|
|
if new_levels[k] == first:
|
|
first_idx = k
|
|
break
|
|
if first_idx is None:
|
|
dim_not_present(first)
|
|
continue # Skip this iteration if dimension not found
|
|
|
|
for j, r in enumerate(rest):
|
|
new_levels.insert(first_idx + 1 + j, r)
|
|
to_flatten.extend(rest)
|
|
else:
|
|
dims_list_flat.append(parse_dim_entry(dims_list[i]))
|
|
|
|
# Handle dimension flattening if needed
|
|
if len(to_flatten) > 0:
|
|
assert self_info.tensor is not None, (
|
|
"Cannot perform dimension flattening on None tensor"
|
|
)
|
|
rearranged = _match_levels(self_info.tensor, self_info.levels, new_levels)
|
|
sizes = rearranged.size()
|
|
new_sizes: list[Any] = []
|
|
reshape_levels = []
|
|
|
|
for i in range(len(new_levels)):
|
|
if new_levels[i] in to_flatten:
|
|
if len(new_sizes) == 0:
|
|
new_sizes.append(sizes[i])
|
|
else:
|
|
new_sizes[-1] *= sizes[i]
|
|
else:
|
|
new_sizes.append(sizes[i])
|
|
reshape_levels.append(new_levels[i])
|
|
|
|
self_info.tensor = rearranged.reshape(new_sizes)
|
|
self_info.levels = reshape_levels
|
|
|
|
# Check for dimpacks in indices
|
|
has_dimpacks = False
|
|
for idx in indices_list:
|
|
if isinstance(idx, (tuple, list)):
|
|
has_dimpacks = True
|
|
break
|
|
|
|
# Call getsetitem_flat with correct parameters
|
|
info = getsetitem_flat(
|
|
self_info,
|
|
[], # empty input_list
|
|
dims_list_flat, # keys
|
|
indices_list, # values
|
|
has_dimpacks,
|
|
)
|
|
|
|
return invoke_getitem(info)
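# Sketch mirroring how cat() below uses this method (not part of the original
# module): passing a list of dims together with a single new dim flattens them
# into that dim (the new dim is bound to the flattened size here if unbound).
#
#   i, j, k = dims(3)
#   x = torch.randn(2, 3)
#   flat = x[i, j].index([i, j], k)     # flatten i and j into k
#   assert k.size == 6
#   assert flat.order(k).shape == (6,)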
|
|
|
|
def __repr__(self) -> str:
|
|
tensor, levels, ndim = self._get_tensor(), self._get_levels(), self.ndim
|
|
dims_repr = []
|
|
for l in levels:
|
|
if hasattr(l, "is_positional") and l.is_positional():
|
|
# Convert negative positional to positive: -1 -> ndim-1, -2 -> ndim-2, etc.
|
|
dims_repr.append(l.position() + ndim)
|
|
elif hasattr(l, "dim"):
|
|
dims_repr.append(l.dim())
|
|
elif hasattr(l, "data"):
|
|
dims_repr.append(l.data)
|
|
else:
|
|
dims_repr.append(l)
|
|
return f"{tensor}\nwith dims={tuple(dims_repr)} sizes={tuple(tensor.size())}" # type: ignore[union-attr]
|
|
|
|
|
|
TensorLike = (_Tensor, torch.Tensor)
|
|
|
|
|
|
class Dim(_Tensor):
|
|
_level: int
|
|
_name: str
|
|
_size: int
|
|
_range: Optional[torch.Tensor]
|
|
_batchtensor: Optional[torch.Tensor]
|
|
|
|
def __init__(self, name: str, s: int = -1) -> None:
|
|
global _n_dims_created
|
|
self._name = name
|
|
self._size = s
|
|
self._level = _n_dims_created
|
|
_n_dims_created += 1
|
|
self._range = None
|
|
self._batchtensor = None
|
|
|
|
@property
|
|
def ndim(self) -> int:
|
|
return 1
|
|
|
|
@classmethod
|
|
def check_exact(cls, obj: Any) -> bool:
|
|
return type(obj) is cls
|
|
|
|
@property
|
|
def size(self) -> int:
|
|
if self._size == -1:
|
|
raise ValueError(f"dimension {self._name} is unbound")
|
|
return self._size
|
|
|
|
@size.setter
|
|
def size(self, v: int) -> None:
|
|
if self._size == -1:
|
|
self._size = v
|
|
elif self._size != v:
|
|
raise DimensionBindError(
|
|
f"Dim '{repr(self)}' previously bound to a dimension of size {self._size} "
|
|
f"cannot bind to a dimension of size {v}"
|
|
)
|
|
|
|
@property
|
|
def is_bound(self) -> bool:
|
|
"""Return True if this dimension is bound to a size."""
|
|
return self._size != -1
|
|
|
|
def _get_range(self) -> torch.Tensor:
|
|
"""
|
|
Get a tensor representing the range [0, size) for this dimension.
|
|
|
|
Returns:
|
|
A 1D tensor with values [0, 1, 2, ..., size-1]
|
|
"""
|
|
if self._range is None:
|
|
self._range = torch.arange(self.size)
|
|
return self._range
|
|
|
|
def _get_batchtensor(self) -> torch.Tensor:
|
|
"""
|
|
Get a batched tensor representation of this dimension.
|
|
|
|
Returns:
|
|
A batched tensor created from the range tensor
|
|
"""
|
|
if self._batchtensor is None:
|
|
self._batchtensor = torch._C._functorch._add_batch_dim(
|
|
self._get_range(), 0, self._level
|
|
)
|
|
return self._batchtensor
|
|
|
|
def __repr__(self) -> str:
|
|
"""String representation of a Dim object."""
|
|
return self._name
|
|
|
|
# note that Dim comes before tensor because we want the Dim API for things like size to take precedence.
|
|
# Tensor defines __format__, but we want Dims to fall back to the default object formatting (printing their name)
|
|
__format__ = object.__format__
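# Behavior sketch for Dim binding (based on the class above; not part of the
# original module):
#
#   d = Dim("d", 3)               # constructed already bound
#   assert d.is_bound and d.size == 3
#
#   e = dims(1)                   # unbound until first use
#   torch.randn(5)[e]             # indexing binds e to 5
#   assert e.size == 5
#   # e.size = 7 would now raise DimensionBindError (already bound to 5)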
|
|
|
|
|
|
# Somewhat confusingly, an FCD tensor is also called Tensor. This confusion
|
|
# is somewhat intentional, as FCD tensors are intended to be substitutable
|
|
# with regular Tensor (just with some positional dims hidden).
|
|
class Tensor(_Tensor):
|
|
_tensor: Optional[torch.Tensor]
|
|
_batchtensor: Optional[torch.Tensor]
|
|
_levels: list[DimEntry]
|
|
_has_device: bool
|
|
_delayed: Optional[Callable[[], torch.Tensor]]
|
|
_delayed_orig: Optional[Callable]
|
|
_delayed_args: Optional[tuple]
|
|
|
|
@property
|
|
def ndim(self) -> int:
|
|
return sum(1 if l.is_positional() else 0 for l in self._levels)
|
|
|
|
@classmethod
|
|
def check_exact(cls, other: Any) -> bool:
|
|
return type(other) is cls
|
|
|
|
@classmethod
|
|
def from_positional(
|
|
cls, tensor: torch.Tensor, levels: list[DimEntry], has_device: bool
|
|
) -> Union[_Tensor, torch.Tensor]:
|
|
"""
|
|
Create a functorch Tensor from a regular PyTorch tensor with specified dimension levels.
|
|
|
|
This is the primary way to create Tensor objects with first-class dimensions.
|
|
|
|
Args:
|
|
tensor: The underlying PyTorch tensor
|
|
levels: List of DimEntry objects specifying the dimension structure
|
|
has_device: Whether the tensor is on a device (not CPU)
|
|
|
|
Returns:
|
|
A new Tensor instance with the specified dimensions, or a regular torch.Tensor
|
|
if there are no named dimensions
|
|
"""
|
|
seen_dims = 0
|
|
last = 0
|
|
|
|
for i, l in enumerate(levels):
|
|
if l.is_positional():
|
|
# Validate consecutive positional dimensions
|
|
assert last == 0 or last + 1 == l.position(), (
|
|
f"Positional dimensions must be consecutive, got {last} then {l.position()}"
|
|
)
|
|
last = l.position()
|
|
else:
|
|
# This is a named dimension
|
|
seen_dims += 1
|
|
|
|
# Validate final positional dimension
|
|
assert last == 0 or last == -1, (
|
|
f"Final positional dimension must be 0 or -1, got {last}"
|
|
)
|
|
|
|
if not seen_dims:
|
|
return tensor
|
|
|
|
# Create Tensor object with proper level management
|
|
result = cls()
|
|
result._tensor = tensor
|
|
result._levels = levels
|
|
result._has_device = has_device
|
|
result._batchtensor = None # Will be created lazily if needed
|
|
result._delayed = None
|
|
result._delayed_orig = None
|
|
result._delayed_args = None
|
|
|
|
# Validate tensor dimensionality matches levels
|
|
assert tensor.dim() == len(levels), (
|
|
f"Tensor has {tensor.dim()} dimensions but {len(levels)} levels provided"
|
|
)
|
|
|
|
return result
|
|
|
|
@classmethod
|
|
def create_delayed(
|
|
cls, orig: Callable, args: tuple, levels: list[DimEntry], has_device: bool
|
|
) -> _Tensor:
|
|
"""
|
|
Create a delayed tensor that defers the operation until later.
|
|
"""
|
|
result = cls()
|
|
result._tensor = None # Will be computed when needed
|
|
result._levels = levels
|
|
result._has_device = has_device
|
|
result._batchtensor = None
|
|
result._delayed_orig = orig
|
|
result._delayed_args = args
|
|
|
|
# Create delayed evaluation function that unwraps Tensor objects
|
|
def evaluate_delayed() -> torch.Tensor:
|
|
unwrapped_args = []
|
|
for arg in args:
|
|
if hasattr(arg, "_get_tensor"):
|
|
unwrapped_args.append(arg._get_tensor())
|
|
else:
|
|
unwrapped_args.append(arg)
|
|
return orig(*unwrapped_args)
|
|
|
|
result._delayed = evaluate_delayed
|
|
|
|
return result
|
|
|
|
def _get_tensor(self) -> Optional[torch.Tensor]:
|
|
"""Get the underlying tensor, handling delayed operations if needed."""
|
|
if (
|
|
hasattr(self, "_delayed")
|
|
and self._delayed is not None
|
|
and self._tensor is None
|
|
):
|
|
# Execute the delayed operation
|
|
self._tensor = self._delayed()
|
|
# Clear delayed operation to avoid re-execution
|
|
self._delayed = None
|
|
self._delayed_orig = None
|
|
self._delayed_args = None
|
|
return self._tensor
|
|
|
|
def _get_levels(self) -> list[Any]:
|
|
"""Get the dimension levels."""
|
|
return self._levels
|
|
|
|
def _get_has_device(self) -> bool:
|
|
"""Get whether this tensor has device information."""
|
|
return self._has_device
|
|
|
|
def _get_batchtensor(self) -> Optional[torch.Tensor]:
|
|
"""Get the batched tensor representation, creating it lazily if needed."""
|
|
if self._batchtensor is None:
|
|
self._batchtensor = self._add_batch_dims(
|
|
self._get_tensor(), self._get_levels()
|
|
)
|
|
return self._batchtensor
|
|
|
|
def _add_batch_dims(
|
|
self, t: Optional[torch.Tensor], levels_: list[Any]
|
|
) -> Optional[torch.Tensor]:
|
|
levels = list(levels_)
|
|
|
|
while True:
|
|
min_real_index = -1
|
|
min_index = -1
|
|
min_value = float("inf") # INT_MAX equivalent
|
|
i = 0
|
|
r = 0
|
|
|
|
for r, l in enumerate(levels):
|
|
if not l.is_none():
|
|
if not l.is_positional() and l.dim()._level < min_value:
|
|
min_value = l.dim()._level
|
|
min_index = i
|
|
min_real_index = r
|
|
i += 1
|
|
|
|
if min_index == -1:
|
|
return t
|
|
|
|
assert t is not None
|
|
t = torch._C._functorch._add_batch_dim(t, min_index, int(min_value))
|
|
|
|
levels[min_real_index] = DimEntry()
|
|
return None
|
|
|
|
def order(self, *dims: Any) -> _Tensor:
|
|
"""Reorder the dimensions of this tensor."""
|
|
from ._order import order
|
|
|
|
result = order(self, *dims)
|
|
return result # type: ignore[return-value] # Tensor and torch.Tensor are interchangeable
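# Usage sketch (not part of the original module): order() turns named dims
# back into positional ones, in the order given.
#
#   i, j = dims(2)
#   x = torch.randn(3, 4)
#   xi = x[i, j]                        # all dims named; underlying data stays (3, 4)
#   assert xi.order(j, i).shape == (4, 3)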
|
|
|
|
|
|
def stack(tensors: Any, new_dim: Any, dim: int = 0) -> _Tensor:
|
|
"""
|
|
Stack tensors along a new dimension.
|
|
|
|
Args:
|
|
tensors: Sequence of tensors to stack
|
|
new_dim: The new Dim to create for stacking
|
|
dim: The dimension position to insert the new dimension (default: 0)
|
|
|
|
Returns:
|
|
Stacked tensor with the new dimension
|
|
"""
|
|
if not tensors:
|
|
raise ValueError("stack expects a non-empty sequence of tensors")
|
|
|
|
# Check if new_dim is a Dim object
|
|
if not isinstance(new_dim, Dim):
|
|
# Fall back to regular torch.stack
|
|
result = torch.stack(tensors, dim=dim)
|
|
return result # type: ignore[return-value]
|
|
|
|
# Collect all result_levels from input tensors
|
|
result_levels = []
|
|
infos = []
|
|
|
|
for t in tensors:
|
|
info = TensorInfo.create(t, ensure_batched=False, ensure_present=False)
|
|
infos.append(info)
|
|
for level in info.levels:
|
|
if level not in result_levels:
|
|
result_levels.append(level)
|
|
|
|
# Set the new_dim size to match number of tensors
|
|
new_dim.size = len(tensors)
|
|
|
|
# Match all tensors to the common level structure using _match_levels
|
|
inputs = []
|
|
for info in infos:
|
|
assert info.tensor is not None, "Cannot stack tensors with None tensor data"
|
|
matched_tensor = _match_levels(info.tensor, info.levels, result_levels)
|
|
inputs.append(matched_tensor)
|
|
|
|
# Calculate ndim and resolve the dim parameter
|
|
ndim = ndim_of_levels(result_levels)
|
|
rawdim = 0
|
|
if dim is not None and not (isinstance(dim, int) and dim == 0):
|
|
from ._wrap import _wrap_dim
|
|
|
|
d = _wrap_dim(dim, ndim, False)
|
|
try:
|
|
idx = result_levels.index(d)
|
|
except ValueError:
|
|
raise TypeError(f"Dimension {dim} does not exist in inputs") from None
|
|
rawdim = idx
|
|
|
|
# Stack tensors at the resolved dimension
|
|
result = torch.stack(inputs, rawdim)
|
|
|
|
# Insert new dimension entry at the correct position
|
|
result_levels.insert(rawdim, DimEntry(new_dim))
|
|
|
|
# Return as a first-class tensor
|
|
tensor_result = Tensor.from_positional(
|
|
result, result_levels, infos[0].has_device if infos else True
|
|
)
|
|
return tensor_result # type: ignore[return-value]
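# Illustrative sketch (not part of the original module): the new dim is bound
# to the number of stacked tensors.
#
#   i, s = dims(2)
#   a = torch.randn(3)[i]
#   b = torch.randn(3)[i]
#   st = stack([a, b], s)               # s.size becomes 2
#   assert st.order(s, i).shape == (2, 3)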
|
|
|
|
|
|
def split(tensor: Any, split_size_or_sections: Any, dim: Any = None) -> tuple:
|
|
"""
|
|
Split tensor along a dimension.
|
|
|
|
Can handle both regular integer sizes and Dim objects for split sizes.
|
|
When Dim objects are used, they get bound to the resulting tensor dimensions.
|
|
"""
|
|
from ._wrap import _wrap_dim
|
|
|
|
# Check if dim is a Dim object
|
|
dim_is_object = isinstance(dim, Dim)
|
|
|
|
# Parse split_size_or_sections
|
|
if isinstance(split_size_or_sections, int):
|
|
# Single integer - use regular split
|
|
if dim_is_object:
|
|
raise TypeError(
|
|
"when dim is specified as a Dim object, split sizes must also be dimensions."
|
|
)
|
|
return _Tensor._torch_function_fallback(
|
|
torch.Tensor.split,
|
|
(type(tensor),),
|
|
(tensor, split_size_or_sections),
|
|
{"dim": dim},
|
|
)
|
|
|
|
# Check if it's a sequence
|
|
sizes = []
|
|
all_dims = True
|
|
all_ints = True
|
|
|
|
for item in split_size_or_sections:
|
|
sizes.append(item)
|
|
if isinstance(item, Dim):
|
|
all_ints = False
|
|
else:
|
|
all_dims = False
|
|
|
|
if all_ints:
|
|
# All integers - use regular split
|
|
if dim_is_object:
|
|
raise TypeError(
|
|
"when dim is specified as a Dim object, split sizes must also be dimensions."
|
|
)
|
|
return _Tensor._torch_function_fallback(
|
|
torch.Tensor.split,
|
|
(type(tensor),),
|
|
(tensor, split_size_or_sections),
|
|
{"dim": dim},
|
|
)
|
|
|
|
if not all_dims:
|
|
raise TypeError("split list must be ints or dims but got a mix")
|
|
|
|
# All are Dim objects - handle first-class dimension split
|
|
self_info = TensorInfo.create(tensor, ensure_batched=False, ensure_present=False)
|
|
ndim = self_info.ndim()
|
|
|
|
if not dim_is_object and ndim == 0:
|
|
raise TypeError("split expects at least a 1-dimension tensor")
|
|
|
|
# Wrap the dimension
|
|
dim_l = _wrap_dim(dim, ndim, False) if dim is not None else DimEntry(-ndim)
|
|
|
|
# Find the index of the dimension in levels
|
|
idx = None
|
|
for i, level in enumerate(self_info.levels):
|
|
if level == dim_l:
|
|
idx = i
|
|
break
|
|
|
|
if idx is None:
|
|
if dim is None:
|
|
dim = 0
|
|
raise TypeError(f"tensor does not contain dimension {dim}")
|
|
|
|
# Calculate split indices
|
|
indices = []
|
|
total_size = 0
|
|
unbound = []
|
|
|
|
for i, size_dim in enumerate(sizes):
|
|
if size_dim.is_bound:
|
|
indices.append(size_dim.size)
|
|
total_size += indices[-1]
|
|
else:
|
|
indices.append(0)
|
|
unbound.append(i)
|
|
|
|
assert self_info.tensor is not None, "Cannot get tensor size on None tensor"
|
|
tensor_size = self_info.tensor.size(idx)
|
|
|
|
# Handle unbound dimensions
|
|
if unbound:
|
|
if total_size > tensor_size:
|
|
raise TypeError(
|
|
f"sizes of target dimensions add up to more ({total_size}) than source dim ({tensor_size})"
|
|
)
|
|
remaining_size = tensor_size - total_size
|
|
chunk_size = (remaining_size + len(unbound) - 1) // len(unbound)
|
|
for u in unbound:
|
|
sz = min(chunk_size, remaining_size)
|
|
sizes[u].size = sz
|
|
indices[u] = sz
|
|
remaining_size -= sz
|
|
elif tensor_size != total_size:
|
|
raise TypeError(
|
|
f"sum of sizes of target dimensions ({total_size}) do not match the source dim ({tensor_size})"
|
|
)
|
|
|
|
# Perform the split
|
|
result_tensors = self_info.tensor.split_with_sizes(indices, idx)
|
|
|
|
# Create result with new levels
|
|
result = []
|
|
new_levels = list(self_info.levels)
|
|
|
|
for i, (result_tensor, size_dim) in enumerate(zip(result_tensors, sizes)):
|
|
new_levels[idx] = DimEntry(size_dim)
|
|
result.append(
|
|
Tensor.from_positional(
|
|
result_tensor, list(new_levels), self_info.has_device
|
|
)
|
|
)
|
|
|
|
return tuple(result)
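# Illustrative sketch (not part of the original module): unbound Dims used as
# split sizes are bound to the resulting chunk sizes.
#
#   a, b = dims(2)
#   x = torch.randn(6)
#   xa, xb = split(x, [a, b], dim=0)    # a and b are each bound to 3
#   assert (a.size, b.size) == (3, 3)
#   assert xa.order(a).shape == (3,)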
|
|
|
|
|
|
def cat(tensors: Any, dim: Any, new_dim: Any) -> _Tensor:
|
|
n = dims(1) # Get single Dim instead of tuple
|
|
return stack(tensors, n, dim).index([n, dim], new_dim) # type: ignore[list-item]
|
|
|
|
|
|
class DotPart:
|
|
"""
|
|
Helper class for organizing dimensions in dot products.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.dims: list[DimEntry] = []
|
|
self.total_size = 1
|
|
|
|
def append(self, dim_entry: Any) -> None:
|
|
"""Add a dimension entry to this part."""
|
|
self.dims.append(dim_entry)
|
|
if not dim_entry.is_positional():
|
|
self.total_size *= dim_entry.dim().size
|
|
|
|
|
|
def dot_prepare(parts: list[DotPart], tensor_info: TensorInfo) -> torch.Tensor:
|
|
"""
|
|
Prepare tensor for dot product by matching levels and reshaping.
|
|
"""
|
|
new_levels = []
|
|
needs_reshape = False
|
|
|
|
for part in parts:
|
|
if len(part.dims) != 1:
|
|
needs_reshape = True
|
|
new_levels.extend(part.dims)
|
|
|
|
if tensor_info.tensor is None:
|
|
raise RuntimeError("Cannot perform dot product on None tensor")
|
|
result = _match_levels(tensor_info.tensor, tensor_info.levels, new_levels)
|
|
|
|
if not needs_reshape:
|
|
return result
|
|
|
|
# Reshape for matrix operations
|
|
view = [part.total_size for part in parts]
|
|
return result.reshape(view)
|
|
|
|
|
|
def dot_finish(parts: list[DotPart], result_tensor: torch.Tensor) -> Tensor:
|
|
"""
|
|
Finish dot product by reshaping result and creating Tensor.
|
|
"""
|
|
result_levels = []
|
|
needs_reshape = False
|
|
|
|
for part in parts:
|
|
if len(part.dims) != 1:
|
|
needs_reshape = True
|
|
result_levels.extend(part.dims)
|
|
|
|
if needs_reshape:
|
|
new_size = []
|
|
for level in result_levels:
|
|
new_size.append(level.dim().size)
|
|
result_tensor = result_tensor.reshape(new_size)
|
|
|
|
tensor_result = Tensor.from_positional(result_tensor, result_levels, True)
|
|
return tensor_result # type: ignore[return-value]
|
|
|
|
|
|
def dot(lhs: Any, rhs: Any, sum_dims: Any) -> Union[_Tensor, torch.Tensor]:
|
|
"""
|
|
Perform dot product between two tensors along specified dimensions.
|
|
|
|
Args:
|
|
lhs: Left-hand side tensor
|
|
rhs: Right-hand side tensor
|
|
sum_dims: Dimensions to sum over (contract)
|
|
|
|
Returns:
|
|
Result of dot product
|
|
"""
|
|
# Get tensor info
|
|
lhs_info = TensorInfo.create(lhs, ensure_batched=False, ensure_present=False)
|
|
rhs_info = TensorInfo.create(rhs, ensure_batched=False, ensure_present=False)
|
|
|
|
if not (lhs_info and rhs_info):
|
|
# Fall back to regular operations
|
|
return torch.matmul(lhs, rhs)
|
|
|
|
assert lhs_info.tensor is not None and rhs_info.tensor is not None, (
|
|
"Cannot perform dot product on None tensors"
|
|
)
|
|
|
|
lhs_strides = lhs_info.tensor.stride()
|
|
rhs_strides = rhs_info.tensor.stride()
|
|
|
|
# Create dot parts for different dimension categories
|
|
lro_dims = DotPart() # Left-right-output (batch dims)
|
|
lo_dims = DotPart() # Left-output only
|
|
ro_dims = DotPart() # Right-output only
|
|
lr_dims = DotPart() # Left-right (contracted dims)
|
|
|
|
def insert_dim(d: Any, lhs_idx: Any, rhs_idx: Any) -> None:
|
|
"""Insert dimension into appropriate part based on stride pattern."""
|
|
reduced = d in sum_dims
|
|
lhs_stride = lhs_strides[lhs_idx] if lhs_idx is not None else 0
|
|
rhs_stride = rhs_strides[rhs_idx] if rhs_idx is not None else 0
|
|
|
|
if reduced:
|
|
lr_dims.append(d)
|
|
else:
|
|
if (lhs_stride == 0) == (rhs_stride == 0):
|
|
lro_dims.append(d) # Both have or both lack this dim
|
|
elif lhs_stride != 0:
|
|
lo_dims.append(d) # Only lhs has this dim
|
|
else:
|
|
ro_dims.append(d) # Only rhs has this dim
|
|
|
|
# Track which rhs dimensions we've seen
|
|
rhs_seen = [False] * len(rhs_info.levels)
|
|
|
|
# Process lhs dimensions
|
|
for i, lhs_level in enumerate(lhs_info.levels):
|
|
rhs_idx = None
|
|
for j, rhs_level in enumerate(rhs_info.levels):
|
|
if lhs_level == rhs_level:
|
|
rhs_idx = j
|
|
rhs_seen[j] = True
|
|
break
|
|
|
|
insert_dim(lhs_level, i, rhs_idx)
|
|
|
|
# Process remaining rhs dimensions
|
|
for i, rhs_level in enumerate(rhs_info.levels):
|
|
if not rhs_seen[i]:
|
|
insert_dim(rhs_level, None, i)
|
|
|
|
# Validate sum dimensions exist
|
|
if len(lr_dims.dims) != len(sum_dims):
|
|
for d in sum_dims:
|
|
if d not in lhs_info.levels and d not in rhs_info.levels:
|
|
raise ValueError(f"summing over non-existent dimension {d}")
|
|
|
|
# Prepare tensors and perform matrix multiplication
|
|
if len(lro_dims.dims) != 0:
|
|
# Batched matrix multiply
|
|
lhs_tensor = dot_prepare([lro_dims, lo_dims, lr_dims], lhs_info)
|
|
rhs_tensor = dot_prepare([lro_dims, lr_dims, ro_dims], rhs_info)
|
|
result = torch.bmm(lhs_tensor, rhs_tensor)
|
|
return dot_finish([lro_dims, lo_dims, ro_dims], result)
|
|
else:
|
|
# Regular matrix multiply
|
|
lhs_tensor = dot_prepare([lo_dims, lr_dims], lhs_info)
|
|
rhs_tensor = dot_prepare([lr_dims, ro_dims], rhs_info)
|
|
result = torch.mm(lhs_tensor, rhs_tensor)
|
|
return dot_finish([lo_dims, ro_dims], result)
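# The classic first-class-dims matrix-multiply idiom is the pattern this
# helper targets: a pointwise multiply whose result is only ever summed over
# the contracted dim, which can then be lowered to a single mm/bmm.
# Illustrative sketch (not part of the original module; relies on the public
# functorch.dim API):
#
#   def fcd_mm(A, B):
#       i, j, k = dims(3)
#       return (A[i, k] * B[k, j]).sum(k).order(i, j)
#
#   assert fcd_mm(torch.randn(3, 4), torch.randn(4, 5)).shape == (3, 5)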
|
|
|
|
|
|
from functorch.dim._wrap import _wrap
|
|
from functorch.dim.wrap_type import wrap_type
|
|
|
|
|
|
wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
|
|
del _Tensor.ndim
|
|
|
|
|
|
def index(self: Any, positions: Any, dims: Any) -> _Tensor:
|
|
"""
|
|
Index a regular tensor by binding specified positions to dims.
|
|
|
|
This converts a regular tensor to a first-class tensor by binding
|
|
the specified positional dimensions to Dim objects.
|
|
|
|
Args:
|
|
positions: Tuple of dimension positions to bind
|
|
dims: Dim objects or tuple of Dim objects to bind to
|
|
|
|
Returns:
|
|
First-class tensor with specified dimensions bound
|
|
"""
|
|
# If this is already a first-class tensor (_Tensor), call its index method directly
|
|
if isinstance(self, _Tensor):
|
|
return _Tensor.index(self, positions, dims)
|
|
|
|
# Convert regular tensor to first-class tensor
|
|
info = TensorInfo.create(self, ensure_batched=False, ensure_present=False)
|
|
|
|
# Create the first-class tensor
|
|
assert info.tensor is not None, "Cannot index None tensor"
|
|
result = Tensor.from_positional(info.tensor, info.levels, info.has_device)
|
|
|
|
# Now call the index method on the first-class tensor
|
|
# Cast result to _Tensor for the method call
|
|
return _Tensor.index(result, positions, dims) # type: ignore[arg-type]
|
|
|
|
|
|
def _def(name: str, *args: Any, **kwargs: Any) -> None:
|
|
orig = getattr(torch.Tensor, name)
|
|
setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
|
|
|
|
|
|
_def("mean")
|
|
_def("sum")
|
|
_def("all")
|
|
_def("amax")
|
|
_def("amin")
|
|
_def("aminmax")
|
|
_def("any")
|
|
_def("count_nonzero")
|
|
_def("logsumexp")
|
|
_def("nanmean")
|
|
_def("nansum")
|
|
_def("prod")
|
|
_def("std", keepdim_offset=2)
|
|
_def("var", keepdim_offset=2)
|
|
_def("max", single_dim=True)
|
|
_def("min", single_dim=True)
|
|
_def("argmax", single_dim=True)
|
|
_def("argmin", single_dim=True)
|
|
_def("kthvalue", single_dim=True)
|
|
_def("median", single_dim=True)
|
|
_def("nanmedian", single_dim=True)
|
|
_def("mode", single_dim=True)
|
|
_def("sort", reduce=False)
|
|
_def("argsort", reduce=False)
|
|
_def("unbind", single_dim=True)
|
|
_def("chunk", dim_offset=1, reduce=False)
|
|
_def("cummax", single_dim=True, reduce=False)
|
|
_def("cummin", single_dim=True, reduce=False)
|
|
_def("cumprod", single_dim=True, reduce=False)
|
|
_def("cumprod_", single_dim=True, reduce=False)
|
|
_def("cumsum", single_dim=True, reduce=False)
|
|
_def("cumsum_", single_dim=True, reduce=False)
|
|
_def("logcumsumexp", single_dim=True, reduce=False)
|
|
_def("renorm", dim_offset=1, single_dim=True, reduce=False)
|
|
_def("softmax", single_dim=True, reduce=False)
|
|
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
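# Usage sketch for the wrapped reductions above (not part of the original
# module): dim arguments can be Dim objects instead of integer positions.
#
#   i, j = dims(2)
#   x = torch.randn(3, 4)[i, j]
#   s = x.sum(i)                  # reduce over the named dim i; j remains
#   assert s.order(j).shape == (4,)
#   p = softmax(x, i)             # module-level wrapper defined above
#   assert p.order(i, j).shape == (3, 4)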
|
|
|
|
# stuff to handle in the future, because they require special
|
|
# binding logic for dims
|
|
# cross
|
|
# diag_embed
|
|
# diagonal
|
|
# diagonal_scatter
|
|
# diff
|
|
# nanquantile
|
|
# quantile
|
|
# roll
|
|
# rot90
|
|
# topk (new dims on output)
|
|
# should these all be subsumed by inplace indexing?
|
|
# index_add_
|
|
# index_add
|
|
# index_copy
|
|
# index_copy_
|
|
# index_fill
|
|
# index_fill_
|
|
# index_select
|
|
# scatter
|
|
# scatter_
|
|
# scatter_add
|
|
# scatter_add_
|
|
# scatter_reduce
|