mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torchdim Python port (#160236)
The big semantic change (and the reason for this port) is that we no longer monkeypatch Tensor with torchdim's special methods. The new algorithm for handling dispatch is that we first land in `__torch_function__` and we see if a special FCD implementation needs to be dispatch to first, and if there is nothing we fallback to the standard level strategy. Because there is no longer C binding equivalent of classes, we've condensed _C.Dim and Dim together, and similar for Tensor. This resulted in some bugs as the Python API is sometimes different from the C API. I've attempted to disambiguate these but there may still be mistakes (many early bugs were due to this problem). Dim and DimEntry are especially painful as Dim must abide by Tensor equality semantics, but is pointer equality in C (DimEntry doesn't have this problem). Another difference between C/Python that is subtle is we no longer get implicit conversions from Dim to DimEntry, this also caused some bugs. Much of the mechanical porting work was done by claude code. I have a separate PR that deletes functorch._C, but it was useful having dim.cpp to point claude at it so I haven't done it in this PR. From a reviewing perspective, I need to re-review that I didn't forget to port anything, some noticeably missing "small" things are patched_dim_method. I am still in progress of carefully doing a side-by-side review of ports; "simplifications" from claude code were also a major source of bugs. There are two major feature gaps in the implementation: - DelayedTensor and dot handling are not implemented yet. This should be reasonably easy, just need to do it. However, for the purposes of sharded propagation it is actually better not to reconstruct matmuls. - Splitting dimensions with an index like `[x, y]` doesn't work. The problem is that `__getitem__` interprets this as advanced indexing and sends the list to torch.tensor to turn into a tensor, instead of being eligible for `__torch_function__`. I think I might need to hard code a special case for this or something? Signed-off-by: Edward Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/160236 Approved by: https://github.com/zdevito, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
2887f3fde4
commit
97eb7a281d
@ -746,12 +746,14 @@ These compilers and language have syntax and semantics that resemble the loop-le
|
||||
|
||||
Dimension objects are just an extension of the existing PyTorch tensors and eager semantics, so there is no friction switching between normal Python code and code that uses them. However, since loops over the dimensions are defined implicitly, they can still execute in Python with good performance compared to explicit loops. Furthermore, with dimension objects, a tensors containing dimensions can compute through code that is oblivious to the dimension such as batching examples. There is no need to separate code into 'compiled' vs 'eager'.
|
||||
|
||||
In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries.
|
||||
In this way, first-class dims are a way of adapting the nicer syntax of these array compilers and languages to eager numpy-style libraries. Note, however, that first class dimensions are not natively compiled, so if you write code that performs many outer products with the expectation of it being fused, you will generally not get good performance or memory use (except for matrix-multiply-like patterns specifically.)
|
||||
|
||||
|
||||
Performance Expectations
|
||||
========================
|
||||
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel. The C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. In the future, the implementation can incorporate more fusion optimization to further improve performance of this style of code.
|
||||
First-class dimensions are not a compiler. They provide syntax for existing PyTorch operations such as advanced indexing that is easier to read and write. For large sized tensors, the performance of any statements including them will be the same as using the already existing operations. An important exception is the pattern matching of products and summation, where performance will be improved by issuing to a matrix-multiply kernel.
|
||||
|
||||
Originally, there was a C++ implementation of dimensions adds a small overhead of around 2us on top of PyTorch's normal overhead of 8us to each function that uses them. However, this implementation had some manual memory managemetn bugs and was not kept up to date with CPython updates. The latest Python implementation is two orders of magnitude slower due to CPU overhead; for overhead sensitive applications you should compile the code to eliminate this overhead.
|
||||
|
||||
|
||||
## License
|
||||
|
File diff suppressed because it is too large
Load Diff
127
functorch/dim/_dim_entry.py
Normal file
127
functorch/dim/_dim_entry.py
Normal file
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from . import Dim
|
||||
|
||||
import torch # noqa: TC002
|
||||
|
||||
|
||||
# NB: The old code represented dimension was from as negative number, so we
|
||||
# follow this convention even though it shouldn't be necessary now
|
||||
class DimEntry:
|
||||
# The dimension this is from the rhs, or a FCD
|
||||
data: Union[Dim, int]
|
||||
|
||||
def __init__(self, data: Union[Dim, int, None] = None) -> None:
|
||||
from . import Dim
|
||||
|
||||
if type(data) is int:
|
||||
assert data < 0
|
||||
elif data is None:
|
||||
data = 0
|
||||
else:
|
||||
assert isinstance(data, Dim)
|
||||
self.data = data
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, DimEntry):
|
||||
return False
|
||||
# Use 'is' for Dim objects to avoid triggering __torch_function__
|
||||
# Use '==' only for positional (int) comparisons
|
||||
if self.is_positional() and other.is_positional():
|
||||
# Both are positional (ints)
|
||||
return self.data == other.data
|
||||
elif not self.is_positional() and not other.is_positional():
|
||||
# Both are Dim objects - use 'is' to avoid __eq__
|
||||
return self.data is other.data
|
||||
else:
|
||||
# One is positional, one is Dim - they can't be equal
|
||||
return False
|
||||
|
||||
def is_positional(self) -> bool:
|
||||
return type(self.data) is int and self.data < 0
|
||||
|
||||
def is_none(self) -> bool:
|
||||
# Use isinstance to check for Dim objects, avoid triggering __torch_function__
|
||||
from . import Dim
|
||||
|
||||
if isinstance(self.data, Dim):
|
||||
# This is a Dim object, it can't be "none" (which is represented by 0)
|
||||
return False
|
||||
else:
|
||||
# This is an int or other type
|
||||
return self.data == 0
|
||||
|
||||
def position(self) -> int:
|
||||
assert isinstance(self.data, int)
|
||||
return self.data
|
||||
|
||||
def dim(self) -> Dim:
|
||||
assert not isinstance(self.data, int)
|
||||
return self.data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.data)
|
||||
|
||||
|
||||
def ndim_of_levels(levels: Sequence[DimEntry]) -> int:
|
||||
r = 0
|
||||
for l in levels:
|
||||
if l.is_positional():
|
||||
r += 1
|
||||
return r
|
||||
|
||||
|
||||
def _match_levels(
|
||||
tensor: torch.Tensor,
|
||||
from_levels: list[DimEntry],
|
||||
to_levels: list[DimEntry],
|
||||
drop_levels: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reshape a tensor to match target levels using as_strided.
|
||||
|
||||
Args:
|
||||
tensor: Input tensor to reshape
|
||||
from_levels: Current levels of the tensor
|
||||
to_levels: Target levels to match
|
||||
drop_levels: If True, missing dimensions are assumed to have stride 0
|
||||
|
||||
Returns:
|
||||
Reshaped tensor
|
||||
"""
|
||||
if from_levels == to_levels:
|
||||
return tensor
|
||||
|
||||
sizes = tensor.size()
|
||||
strides = tensor.stride()
|
||||
|
||||
if not drop_levels:
|
||||
assert len(from_levels) <= len(to_levels), (
|
||||
"Cannot expand dimensions without drop_levels"
|
||||
)
|
||||
|
||||
new_sizes = []
|
||||
new_strides = []
|
||||
|
||||
for level in to_levels:
|
||||
# Find index of this level in from_levels
|
||||
try:
|
||||
idx = from_levels.index(level)
|
||||
except ValueError:
|
||||
# Level not found in from_levels
|
||||
if level.is_positional():
|
||||
new_sizes.append(1)
|
||||
else:
|
||||
new_sizes.append(level.dim().size)
|
||||
new_strides.append(0)
|
||||
else:
|
||||
new_sizes.append(sizes[idx])
|
||||
new_strides.append(strides[idx])
|
||||
|
||||
return tensor.as_strided(new_sizes, new_strides, tensor.storage_offset())
|
139
functorch/dim/_enable_all_layers.py
Normal file
139
functorch/dim/_enable_all_layers.py
Normal file
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from ._dim_entry import DimEntry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Dim, Tensor
|
||||
|
||||
|
||||
class EnableAllLayers:
|
||||
"""
|
||||
RAII-style context manager for enabling functorch vmap layers.
|
||||
It manages the creation and cleanup of functorch dynamic layers.
|
||||
|
||||
This is probably one of the more algorithmically important parts of first
|
||||
class dims. Intuitively, FCD can be thought of as another way of using
|
||||
vmap, where you don't actually have to vmap at the top level, instead the
|
||||
vmaps are implicitly determined by inspecting the bound dimensions on the
|
||||
FCD tensors involved in a compute (this is similar to our concept of
|
||||
non-lexical modes that we spent a long time talking about years ago). But
|
||||
under the hood you still need to actually enable the vmap mode. So once
|
||||
FCD has determined all of the dims we are batching over, it needs to
|
||||
enable all those layers so functorch can actually apply the batching
|
||||
rules. Therefore enable all layers!
|
||||
"""
|
||||
|
||||
levels_start: int
|
||||
levels_to_dim: list[Dim]
|
||||
|
||||
def __init__(self, levels: list[DimEntry]):
|
||||
"""
|
||||
Initialize and push dynamic layers for all first-class dimensions.
|
||||
|
||||
Args:
|
||||
levels: List of dimension entries to create layers for
|
||||
"""
|
||||
|
||||
from . import Dim
|
||||
|
||||
self.levels_start = 0
|
||||
self.levels_to_dim = []
|
||||
|
||||
for l in levels:
|
||||
if not l.is_positional():
|
||||
d = l.dim()
|
||||
assert isinstance(d, Dim)
|
||||
self.levels_to_dim.append(d)
|
||||
|
||||
# Sort by level for stable ordering
|
||||
self.levels_to_dim.sort(key=lambda d: d._level)
|
||||
|
||||
def __enter__(self) -> EnableAllLayers: # noqa: PYI034
|
||||
# Create functorch dynamic layers
|
||||
for i, dim in enumerate(self.levels_to_dim):
|
||||
batch_size = dim.size
|
||||
level = torch._C._functorch._vmap_increment_nesting(batch_size, "different")
|
||||
if i == 0:
|
||||
self.levels_start = level
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Clean up dynamic layers in reverse order."""
|
||||
to_remove = self.levels_start + len(self.levels_to_dim) - 1
|
||||
for i in range(len(self.levels_to_dim)):
|
||||
popped = torch._C._functorch._vmap_decrement_nesting()
|
||||
assert popped == to_remove - i, (
|
||||
f"Expected layer {to_remove - i}, got {popped}"
|
||||
)
|
||||
|
||||
def from_batched(self, batchedtensor: torch.Tensor, has_device: bool) -> Tensor:
|
||||
"""
|
||||
Create a Tensor from a batched tensor by unwrapping functorch layers.
|
||||
|
||||
Args:
|
||||
batchedtensor: Batched tensor from functorch operation
|
||||
has_device: Whether tensor has device info
|
||||
|
||||
Returns:
|
||||
Tensor with appropriate levels
|
||||
"""
|
||||
# Create positional levels for base dimensions
|
||||
levels: list[DimEntry] = []
|
||||
for i in range(-batchedtensor.dim(), 0):
|
||||
levels.append(DimEntry(i))
|
||||
|
||||
tensor = batchedtensor
|
||||
|
||||
while torch._C._functorch.is_batchedtensor(tensor):
|
||||
level = torch._C._functorch.maybe_get_level(tensor)
|
||||
assert level is not None
|
||||
assert level >= self.levels_start and level < self.levels_start + len(
|
||||
self.levels_to_dim
|
||||
)
|
||||
dim = DimEntry(self.levels_to_dim[level - self.levels_start])
|
||||
bdim = torch._C._functorch.maybe_get_bdim(tensor)
|
||||
assert bdim is not None
|
||||
levels.insert(bdim, dim)
|
||||
tensor = torch._C._functorch.get_unwrapped(tensor)
|
||||
|
||||
from . import Tensor
|
||||
|
||||
result = Tensor()
|
||||
result._tensor = tensor
|
||||
result._batchtensor = batchedtensor
|
||||
result._has_device = has_device
|
||||
result._levels = levels
|
||||
return result
|
||||
|
||||
def inplace_update_layers(
|
||||
self, batchtensor: torch.Tensor, levels: list[DimEntry]
|
||||
) -> None:
|
||||
"""
|
||||
Update the levels of a batched tensor in place.
|
||||
|
||||
This requires the _maybe_unsafe_set_level binding that we'll add to functorch.
|
||||
|
||||
Args:
|
||||
batchtensor: Batched tensor to update
|
||||
levels: New levels to set
|
||||
"""
|
||||
# Check if tensor is batched
|
||||
if not torch._C._functorch.is_batchedtensor(batchtensor):
|
||||
return
|
||||
|
||||
impl = batchtensor
|
||||
|
||||
for i in reversed(range(len(self.levels_to_dim))):
|
||||
if impl is None:
|
||||
break
|
||||
|
||||
if any(l == DimEntry(self.levels_to_dim[i]) for l in levels):
|
||||
# This is very interesting! The level on batch tensor is
|
||||
# meaningless! We set it RIGHT before we go into vmap
|
||||
torch._C._functorch._maybe_unsafe_set_level(impl, self.levels_start + i)
|
||||
impl = torch._C._functorch.get_unwrapped(impl)
|
561
functorch/dim/_getsetitem.py
Normal file
561
functorch/dim/_getsetitem.py
Normal file
@ -0,0 +1,561 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ._dim_entry import _match_levels, DimEntry
|
||||
from ._tensor_info import TensorInfo
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Dim
|
||||
|
||||
|
||||
def _safe_index(lst: list, item: Any) -> Optional[int]:
|
||||
"""
|
||||
Helper function to find index of item in list.
|
||||
|
||||
For DimEntry objects, uses __eq__ comparison which properly handles
|
||||
both positional and Dim entries.
|
||||
|
||||
Returns the index if found, None if not found.
|
||||
"""
|
||||
for i, list_item in enumerate(lst):
|
||||
# Use == for DimEntry objects as they have proper __eq__ implementation
|
||||
if isinstance(item, DimEntry) and isinstance(list_item, DimEntry):
|
||||
if list_item == item:
|
||||
return i
|
||||
elif list_item is item:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexingInfo:
|
||||
can_call_original: bool = False
|
||||
advanced_indexing: bool = False
|
||||
self_tensor: Optional[torch.Tensor] = None
|
||||
flat_inputs: list[Any] = field(default_factory=list)
|
||||
result_levels: list[DimEntry] = field(default_factory=list)
|
||||
has_device: bool = False
|
||||
|
||||
|
||||
def has_dims(obj: Any) -> bool:
|
||||
"""
|
||||
Check if an object has first-class dimensions.
|
||||
|
||||
This function checks if the object is either a Dim or a functorch Tensor
|
||||
that has first-class dimensions, using the proper check_exact methods.
|
||||
"""
|
||||
from . import Dim, Tensor
|
||||
|
||||
return Dim.check_exact(obj) or Tensor.check_exact(obj)
|
||||
|
||||
|
||||
def _bind_dims_to_size(sz: int, sd: int, dims: list, nsz: list, nsd: list) -> None:
|
||||
"""
|
||||
Bind dimensions to size and calculate proper strides for dim packs.
|
||||
"""
|
||||
from . import DimensionBindError
|
||||
|
||||
rhs_prod = 1
|
||||
for i, dim in enumerate(dims):
|
||||
if not dim.is_bound:
|
||||
# Check for multiple unbound dimensions
|
||||
for j in range(i + 1, len(dims)):
|
||||
if not dims[j].is_bound:
|
||||
raise DimensionBindError(
|
||||
f"cannot infer the sizes of two dimensions at once {dim!r} and {dims[j]!r}"
|
||||
)
|
||||
rhs_prod *= dims[j].size
|
||||
|
||||
# Calculate the size for this unbound dimension
|
||||
if sz % rhs_prod != 0:
|
||||
tup = tuple(dim.size if dim.is_bound else "?" for dim in dims)
|
||||
raise DimensionBindError(
|
||||
f"inferred dimension does not evenly fit into larger dimension: {sz} vs {tup}"
|
||||
)
|
||||
|
||||
inferred_size = sz // rhs_prod
|
||||
dim.size = inferred_size
|
||||
rhs_prod = sz
|
||||
break
|
||||
else:
|
||||
rhs_prod *= dim.size
|
||||
|
||||
# Final validation that dimensions match
|
||||
if rhs_prod != sz:
|
||||
tup = tuple(dims)
|
||||
raise DimensionBindError(
|
||||
f"Dimension sizes to do not match ({sz} != {rhs_prod}) when matching dimension pack {tup}"
|
||||
)
|
||||
|
||||
# Calculate new sizes and strides for each dimension in the pack
|
||||
# First calculate all strides by iterating in reverse
|
||||
new_strides = [0] * len(dims)
|
||||
current_stride = sd
|
||||
for i in reversed(range(len(dims))):
|
||||
new_strides[i] = current_stride
|
||||
current_stride *= dims[i].size
|
||||
|
||||
# Then append sizes and strides in forward order
|
||||
for i in range(len(dims)):
|
||||
nsz.append(dims[i].size)
|
||||
nsd.append(new_strides[i])
|
||||
|
||||
|
||||
def slice_to_tuple(flat_inputs: list) -> tuple:
|
||||
return tuple(flat_inputs)
|
||||
|
||||
|
||||
def extractIndices(index: Any, indices: list) -> bool:
|
||||
if isinstance(index, tuple): # mpy::tuple_view::check
|
||||
indices.extend(index)
|
||||
return True
|
||||
elif isinstance(index, torch.Tensor): # THPVariable_Check
|
||||
indices.append(index)
|
||||
return False
|
||||
elif not hasattr(index, "__iter__") or isinstance(
|
||||
index, (str, bytes)
|
||||
): # !mpy::is_sequence
|
||||
indices.append(index)
|
||||
return False
|
||||
|
||||
# Handle sequence case (list)
|
||||
if isinstance(index, list):
|
||||
if len(index) >= 32:
|
||||
indices.extend(index)
|
||||
return True
|
||||
|
||||
# Check each item in the sequence
|
||||
for item in index:
|
||||
if (
|
||||
isinstance(item, (torch.Tensor, slice))
|
||||
or hasattr(item, "__iter__")
|
||||
or item is ...
|
||||
or item is None
|
||||
or has_dims(item)
|
||||
):
|
||||
indices.extend(index)
|
||||
return True
|
||||
|
||||
# If we got here, treat as single index
|
||||
indices.append(index)
|
||||
return False
|
||||
|
||||
# Default case
|
||||
indices.append(index)
|
||||
return False
|
||||
|
||||
|
||||
def getitem(cls: Any, func: Any, types: Any, args: Any, kwargs: Any) -> Any:
|
||||
self = args[0]
|
||||
index = args[1]
|
||||
|
||||
iinfo = getsetitem(self, index, has_dims(self))
|
||||
if iinfo.can_call_original:
|
||||
# Call original tensor __getitem__ directly, bypassing __torch_function__
|
||||
return torch.Tensor.__getitem__(self, index)
|
||||
|
||||
return invoke_getitem(iinfo)
|
||||
|
||||
|
||||
def setitem(self: Any, index: Any, rhs: Any) -> None:
|
||||
"""Set values in tensor using first-class dimensions."""
|
||||
from . import DimensionBindError, TensorInfo
|
||||
|
||||
iinfo = getsetitem(self, index, has_dims(self) or has_dims(rhs))
|
||||
|
||||
if iinfo.can_call_original:
|
||||
# Call original tensor __setitem__ directly, bypassing __torch_function__
|
||||
torch._C.TensorBase.__setitem__(self, index, rhs)
|
||||
return
|
||||
|
||||
# Handle RHS tensor with dimensions
|
||||
rhs_info = TensorInfo.create(rhs, False, False)
|
||||
|
||||
if rhs_info:
|
||||
# Check that rhs dimensions are compatible with result dimensions
|
||||
for l in rhs_info.levels:
|
||||
if not l.is_positional():
|
||||
# Find this dimension in result levels
|
||||
found = False
|
||||
for result_level in iinfo.result_levels:
|
||||
if (
|
||||
not result_level.is_positional()
|
||||
and result_level.dim() is l.dim()
|
||||
):
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
# Create tuple representation of result levels for error message
|
||||
result_dims: list[Union[int, Dim]] = []
|
||||
for rl in iinfo.result_levels:
|
||||
if rl.is_positional():
|
||||
result_dims.append(rl.position())
|
||||
else:
|
||||
result_dims.append(rl.dim())
|
||||
|
||||
raise DimensionBindError(
|
||||
f"rhs of setitem contains dimension {l.dim()!r} which is not in the dimension on the left "
|
||||
f"({tuple(result_dims)!r})"
|
||||
)
|
||||
|
||||
# Match RHS tensor to result levels
|
||||
assert rhs_info.tensor is not None, "Cannot match levels on None tensor"
|
||||
matched_rhs = _match_levels(
|
||||
rhs_info.tensor, rhs_info.levels, iinfo.result_levels
|
||||
)
|
||||
else:
|
||||
matched_rhs = rhs
|
||||
|
||||
# For advanced indexing with dimensions, we need special handling
|
||||
if iinfo.advanced_indexing:
|
||||
# Use advanced indexing - the flat_inputs already contain matched tensors
|
||||
tup = slice_to_tuple(iinfo.flat_inputs)
|
||||
if iinfo.self_tensor is None:
|
||||
raise RuntimeError("Cannot setitem on None tensor")
|
||||
torch._C.TensorBase.__setitem__(iinfo.self_tensor, tup, matched_rhs)
|
||||
else:
|
||||
# Simple copy operation
|
||||
if iinfo.self_tensor is None:
|
||||
raise RuntimeError("Cannot copy to None tensor")
|
||||
iinfo.self_tensor.copy_(matched_rhs)
|
||||
|
||||
|
||||
def invoke_getitem(iinfo: IndexingInfo) -> Any:
|
||||
if iinfo.advanced_indexing:
|
||||
self_tensor = iinfo.self_tensor
|
||||
tup = slice_to_tuple(iinfo.flat_inputs)
|
||||
if self_tensor is None:
|
||||
raise RuntimeError("Cannot getitem on None tensor")
|
||||
rtensor = self_tensor[tup]
|
||||
else:
|
||||
rtensor = iinfo.self_tensor # type: ignore[assignment]
|
||||
if rtensor is None:
|
||||
raise RuntimeError("Cannot getitem on None tensor")
|
||||
# rtensor is now guaranteed to be not None
|
||||
|
||||
# Create a Tensor with the proper dimensions using the class method
|
||||
from . import Tensor
|
||||
|
||||
return Tensor.from_positional(rtensor, iinfo.result_levels, iinfo.has_device)
|
||||
|
||||
|
||||
def getsetitem(self: Any, index: Any, tensors_have_dims: bool) -> IndexingInfo:
|
||||
from . import DimList # Import DimList for type checking
|
||||
|
||||
can_call_original_getitem = not tensors_have_dims
|
||||
|
||||
input_list = []
|
||||
if has_dims(index):
|
||||
input_list.append(index)
|
||||
else:
|
||||
is_sequence = extractIndices(index, input_list)
|
||||
# nothing about first class dims here, fallback to getitem
|
||||
if can_call_original_getitem and not is_sequence:
|
||||
return IndexingInfo(can_call_original=True)
|
||||
|
||||
# Calculate how many dimensions have been indexed in order to compute the
|
||||
# size of ... or expand a potentially unbound dimension list.
|
||||
dims_indexed = 0
|
||||
expanding_object = -1
|
||||
unbound_dim_list = None
|
||||
dimlists = [] # Track DimList positions for later processing
|
||||
|
||||
def check_expanding(i: int) -> None:
|
||||
nonlocal expanding_object
|
||||
if expanding_object != -1:
|
||||
from . import DimensionBindError
|
||||
|
||||
raise DimensionBindError(
|
||||
f"at most one ... or unbound dimension list can exist in indexing list but found 2 at offsets "
|
||||
f"{expanding_object} and {i}"
|
||||
)
|
||||
expanding_object = i
|
||||
|
||||
def is_dimpack(s: Any) -> bool:
|
||||
from . import Dim
|
||||
|
||||
return (
|
||||
isinstance(s, (tuple, list))
|
||||
and len(s) > 0
|
||||
and all(Dim.check_exact(item) for item in s)
|
||||
)
|
||||
|
||||
has_dimpacks_or_none = False
|
||||
for i, s in enumerate(input_list):
|
||||
if has_dims(s):
|
||||
can_call_original_getitem = False
|
||||
dims_indexed += 1
|
||||
elif s is ...:
|
||||
check_expanding(i)
|
||||
elif isinstance(s, DimList):
|
||||
can_call_original_getitem = False
|
||||
if not s.is_bound:
|
||||
check_expanding(i)
|
||||
unbound_dim_list = s
|
||||
else:
|
||||
dims_indexed += len(s._dims)
|
||||
dimlists.append(i)
|
||||
elif s is None:
|
||||
has_dimpacks_or_none = True
|
||||
elif is_dimpack(s):
|
||||
can_call_original_getitem = False
|
||||
has_dimpacks_or_none = True
|
||||
dims_indexed += 1
|
||||
else:
|
||||
dims_indexed += 1
|
||||
|
||||
# Early return if we can use original getitem
|
||||
if can_call_original_getitem:
|
||||
return IndexingInfo(can_call_original=True)
|
||||
|
||||
self_info = TensorInfo.create(self, False, True)
|
||||
total_dims = len(self_info.levels) # Total dimensions (positional + named)
|
||||
if dims_indexed > total_dims:
|
||||
raise ValueError(
|
||||
f"at least {dims_indexed} indices were supplied but the tensor only has {total_dims} dimensions"
|
||||
)
|
||||
|
||||
# Expand any unbound dimension list, or expand ... into individual : slices.
|
||||
expanding_dims = total_dims - dims_indexed
|
||||
if expanding_object != -1:
|
||||
if unbound_dim_list is not None:
|
||||
# Bind unbound dimension list to the expanding dimensions
|
||||
unbound_dim_list.bind_len(expanding_dims)
|
||||
else:
|
||||
# Expand ... into slice(None) objects
|
||||
no_slices = [slice(None)] * expanding_dims
|
||||
input_list = (
|
||||
input_list[:expanding_object]
|
||||
+ no_slices
|
||||
+ input_list[expanding_object + 1 :]
|
||||
)
|
||||
|
||||
# Flatten out any dimensions stored in dimlist elements directly into the inputs
|
||||
# Process in reverse order to maintain indices
|
||||
for i in range(len(dimlists) - 1, -1, -1):
|
||||
idx = dimlists[i]
|
||||
|
||||
# We added more elements to input because of ...
|
||||
# so we need to also adjust the index to get back to where the
|
||||
# dimlist existed
|
||||
if (
|
||||
unbound_dim_list is None
|
||||
and expanding_object != -1
|
||||
and idx > expanding_object
|
||||
):
|
||||
idx += expanding_dims
|
||||
|
||||
dl = input_list[idx]
|
||||
|
||||
# PRIVATE here naughty
|
||||
input_list = input_list[:idx] + dl._dims + input_list[idx + 1 :]
|
||||
|
||||
return getsetitem_flat(self_info, input_list, [], [], has_dimpacks_or_none)
|
||||
|
||||
|
||||
def getsetitem_flat(
|
||||
self_info: TensorInfo,
|
||||
input_list: list,
|
||||
keys: list[DimEntry],
|
||||
values: list,
|
||||
has_dimpacks_or_none: bool,
|
||||
) -> IndexingInfo:
|
||||
from . import Dim
|
||||
|
||||
# Track dimension usage
|
||||
seen_dims: list[Any] = []
|
||||
seen_dims_nuses: list[int] = []
|
||||
|
||||
def add_dim(dim: Any) -> None:
|
||||
# Use safe indexing to avoid triggering __torch_function__ on Dim objects
|
||||
idx = _safe_index(seen_dims, dim)
|
||||
if idx is not None:
|
||||
seen_dims_nuses[idx] += 1
|
||||
else:
|
||||
seen_dims.append(dim)
|
||||
seen_dims_nuses.append(1)
|
||||
|
||||
flat_inputs = []
|
||||
tensor_inputs: list[Any] = []
|
||||
device_holding_tensor = None
|
||||
|
||||
def append_flat_handle(handle: Any) -> None:
|
||||
flat_inputs.append(handle)
|
||||
tensor_inputs.append(None)
|
||||
|
||||
def append_tensor_input(ti: TensorInfo) -> None:
|
||||
flat_inputs.append(None)
|
||||
tensor_inputs.append(ti)
|
||||
nonlocal device_holding_tensor
|
||||
if ti.has_device and device_holding_tensor is None:
|
||||
device_holding_tensor = ti.tensor
|
||||
|
||||
nsz = []
|
||||
nsd = []
|
||||
if self_info.tensor is None:
|
||||
raise RuntimeError("Cannot get size/stride on None tensor")
|
||||
sz = self_info.tensor.size()
|
||||
sd = self_info.tensor.stride()
|
||||
|
||||
def append_size(i: int) -> None:
|
||||
if has_dimpacks_or_none:
|
||||
nsz.append(sz[i])
|
||||
nsd.append(sd[i])
|
||||
|
||||
input_it = input_list[:]
|
||||
|
||||
def parse_nones() -> None:
|
||||
nonlocal input_it
|
||||
while input_it and input_it[0] is None:
|
||||
append_flat_handle(slice(None))
|
||||
nsz.append(1)
|
||||
nsd.append(0)
|
||||
input_it = input_it[1:]
|
||||
|
||||
def append_item(i: int, arg: Any) -> None:
|
||||
if Dim.check_exact(arg):
|
||||
d = arg
|
||||
if d._size == -1:
|
||||
d.size = sz[i]
|
||||
add_dim(d)
|
||||
append_size(i)
|
||||
append_flat_handle(arg)
|
||||
return
|
||||
|
||||
info = TensorInfo.create(arg, False, False)
|
||||
if info:
|
||||
append_size(i)
|
||||
append_tensor_input(info)
|
||||
for level in info.levels:
|
||||
if not level.is_positional():
|
||||
add_dim(level.dim())
|
||||
return
|
||||
|
||||
if has_dimpacks_or_none:
|
||||
if isinstance(arg, (tuple, list)) and all(Dim.check_exact(d) for d in arg):
|
||||
# dim pack
|
||||
dim_pack = list(arg)
|
||||
for d in dim_pack:
|
||||
add_dim(d)
|
||||
append_flat_handle(d)
|
||||
_bind_dims_to_size(sz[i], sd[i], dim_pack, nsz, nsd)
|
||||
return
|
||||
|
||||
append_size(i)
|
||||
append_flat_handle(arg)
|
||||
|
||||
# Match indexing expressions with tensor dimensions
|
||||
for i, level in enumerate(self_info.levels):
|
||||
# Use safe indexing to avoid triggering __torch_function__ on DimEntry comparisons
|
||||
idx = _safe_index(keys, level)
|
||||
if idx is not None:
|
||||
append_item(i, values[idx])
|
||||
else:
|
||||
if level.is_positional():
|
||||
parse_nones()
|
||||
if not input_it:
|
||||
append_flat_handle(slice(None))
|
||||
append_size(i)
|
||||
else:
|
||||
arg = input_it[0]
|
||||
input_it = input_it[1:]
|
||||
append_item(i, arg)
|
||||
else:
|
||||
add_dim(level.dim())
|
||||
append_flat_handle(level.dim())
|
||||
append_size(i)
|
||||
|
||||
parse_nones()
|
||||
|
||||
# Restride tensor if needed
|
||||
if has_dimpacks_or_none and nsz:
|
||||
if self_info.tensor is None:
|
||||
raise RuntimeError("Cannot restride None tensor")
|
||||
self_tensor = self_info.tensor.as_strided(
|
||||
nsz, nsd, self_info.tensor.storage_offset()
|
||||
)
|
||||
else:
|
||||
self_tensor = self_info.tensor
|
||||
|
||||
# Determine result shape and indexing requirements
|
||||
result_levels: list[Any] = []
|
||||
index_levels = []
|
||||
tensor_insert_point = -1
|
||||
requires_getindex = False
|
||||
|
||||
def mark_tensor_index() -> None:
|
||||
nonlocal tensor_insert_point
|
||||
if tensor_insert_point == -1:
|
||||
tensor_insert_point = len(result_levels)
|
||||
elif tensor_insert_point != len(result_levels):
|
||||
tensor_insert_point = 0
|
||||
|
||||
for i, inp in enumerate(flat_inputs):
|
||||
if tensor_inputs[i] is not None:
|
||||
requires_getindex = True
|
||||
mark_tensor_index()
|
||||
for level in tensor_inputs[i].levels:
|
||||
if level not in index_levels:
|
||||
index_levels.append(level)
|
||||
elif Dim.check_exact(inp):
|
||||
d = inp
|
||||
# Use safe indexing to avoid triggering __torch_function__
|
||||
dim_idx = _safe_index(seen_dims, d)
|
||||
assert dim_idx is not None, f"Dim {d} not found in seen_dims"
|
||||
if seen_dims_nuses[dim_idx] == 1:
|
||||
flat_inputs[i] = slice(None)
|
||||
result_levels.append(DimEntry(d))
|
||||
else:
|
||||
requires_getindex = True
|
||||
flat_inputs[i] = None
|
||||
tensor_inputs[i] = TensorInfo(
|
||||
d._get_range(), [DimEntry(d)], False, None
|
||||
)
|
||||
if DimEntry(d) not in index_levels:
|
||||
index_levels.append(DimEntry(d))
|
||||
mark_tensor_index()
|
||||
else:
|
||||
if inp != slice(None):
|
||||
requires_getindex = True
|
||||
if not isinstance(inp, int):
|
||||
result_levels.append(DimEntry(-1))
|
||||
|
||||
# Insert indexing dimensions at first tensor use point
|
||||
if tensor_insert_point != -1:
|
||||
for level in reversed(index_levels):
|
||||
result_levels.insert(tensor_insert_point, level)
|
||||
|
||||
# Match tensors to indexing shape
|
||||
if requires_getindex:
|
||||
for i in range(len(flat_inputs)):
|
||||
if tensor_inputs[i] is not None:
|
||||
t = tensor_inputs[i].tensor
|
||||
assert t is not None, "TensorInfo should have valid tensor data"
|
||||
if (
|
||||
not tensor_inputs[i].has_device
|
||||
and device_holding_tensor is not None
|
||||
):
|
||||
t = t.to(device_holding_tensor.device)
|
||||
flat_inputs[i] = _match_levels(t, tensor_inputs[i].levels, index_levels)
|
||||
|
||||
# Number positional dimensions correctly
|
||||
seen_positionals = 0
|
||||
for i in reversed(range(len(result_levels))):
|
||||
if result_levels[i].is_positional():
|
||||
seen_positionals += 1
|
||||
result_levels[i] = DimEntry(-seen_positionals)
|
||||
|
||||
return IndexingInfo(
|
||||
can_call_original=False,
|
||||
advanced_indexing=requires_getindex,
|
||||
self_tensor=self_tensor,
|
||||
flat_inputs=flat_inputs,
|
||||
result_levels=result_levels,
|
||||
has_device=self_info.has_device,
|
||||
)
|
214
functorch/dim/_order.py
Normal file
214
functorch/dim/_order.py
Normal file
@ -0,0 +1,214 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch # noqa: TC002
|
||||
|
||||
from ._dim_entry import _match_levels, DimEntry, ndim_of_levels
|
||||
|
||||
|
||||
def _wrap_dim(arg: Any, orig_ndim: int, allow_none: bool = True) -> DimEntry:
|
||||
"""
|
||||
Convert various dimension representations to DimEntry.
|
||||
|
||||
Args:
|
||||
arg: The argument to convert (Dim, int, or other)
|
||||
orig_ndim: Original number of dimensions
|
||||
allow_none: Whether to allow None values
|
||||
|
||||
Returns:
|
||||
DimEntry representation of the dimension
|
||||
"""
|
||||
from . import Dim
|
||||
|
||||
if arg is None and allow_none:
|
||||
return DimEntry() # None entry
|
||||
elif isinstance(arg, Dim):
|
||||
return DimEntry(arg)
|
||||
elif isinstance(arg, int):
|
||||
if arg < 0:
|
||||
pos = arg
|
||||
else:
|
||||
pos = arg - orig_ndim
|
||||
return DimEntry(pos)
|
||||
else:
|
||||
return DimEntry()
|
||||
|
||||
|
||||
def order(
|
||||
tensor_or_dim: Union[torch.Tensor, Any], *dims: Union[Any, Sequence[Any]]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Reorder the dimensions of a tensor or create a tensor from a dimension.
|
||||
|
||||
It allows reordering tensor dimensions using first-class dimensions and
|
||||
positional indices.
|
||||
|
||||
Args:
|
||||
tensor_or_dim: Input tensor with first-class dimensions, or a Dim object
|
||||
*dims: Dimensions or sequences of dimensions specifying the new order
|
||||
|
||||
Returns:
|
||||
Tensor with reordered dimensions
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from functorch.dim import dims
|
||||
>>> batch, channel, height, width = dims(4)
|
||||
>>> x = torch.randn(2, 3, 4, 5)[batch, channel, height, width]
|
||||
>>> # Reorder to [height, width, batch, channel]
|
||||
>>> y = order(x, height, width, batch, channel)
|
||||
"""
|
||||
from . import Dim, DimList, Tensor
|
||||
|
||||
# Handle first argument - tensor or dimension
|
||||
if isinstance(tensor_or_dim, Tensor):
|
||||
# First-class tensor
|
||||
orig_levels = tensor_or_dim._levels[:]
|
||||
data = tensor_or_dim._tensor
|
||||
has_device = tensor_or_dim._has_device
|
||||
elif isinstance(tensor_or_dim, Dim):
|
||||
# Single dimension - create range tensor
|
||||
orig_levels = [DimEntry(tensor_or_dim)]
|
||||
data = tensor_or_dim._get_range()
|
||||
has_device = False
|
||||
else:
|
||||
raise ValueError("First argument must be a Tensor or Dim object")
|
||||
|
||||
flat_positional_dims = []
|
||||
to_flatten = [] # List of (start_index, length) pairs for flattening
|
||||
levels = orig_levels[:]
|
||||
|
||||
orig_ndim = ndim_of_levels(levels)
|
||||
|
||||
def append_dim(d: DimEntry) -> None:
|
||||
"""Add a dimension to the reordering, removing it from available levels."""
|
||||
try:
|
||||
idx = levels.index(d)
|
||||
except ValueError:
|
||||
idx = None
|
||||
if idx is None:
|
||||
if d.is_positional():
|
||||
raise ValueError(
|
||||
f"tensor has {orig_ndim} positional dimensions, but {d.position() + orig_ndim} specified, "
|
||||
f"or it was specified twice"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"tensor does not contain dim {d.dim()} or it was specified twice"
|
||||
)
|
||||
|
||||
levels[idx] = DimEntry()
|
||||
flat_positional_dims.append(d)
|
||||
|
||||
n_new_positional = 0
|
||||
|
||||
# Process each dimension argument
|
||||
for arg in dims:
|
||||
entry = _wrap_dim(arg, orig_ndim, False)
|
||||
if not entry.is_none():
|
||||
append_dim(entry)
|
||||
n_new_positional += 1
|
||||
elif isinstance(arg, DimList):
|
||||
# Handle DimList
|
||||
for dim in arg._dims:
|
||||
append_dim(DimEntry(dim))
|
||||
n_new_positional += 1
|
||||
else:
|
||||
# Handle sequences of dimensions for flattening
|
||||
n_new_positional += 1
|
||||
if not hasattr(arg, "__iter__"):
|
||||
raise ValueError("expected a Dim, List[Dim], or Sequence[Dim]")
|
||||
|
||||
# Convert to list to get length
|
||||
seq = list(arg)
|
||||
to_flatten.append((len(flat_positional_dims), len(seq)))
|
||||
|
||||
for item in seq:
|
||||
entry = _wrap_dim(item, orig_ndim, False)
|
||||
if entry.is_none():
|
||||
raise ValueError("expected a Dim or int")
|
||||
append_dim(entry)
|
||||
|
||||
# Build new level ordering
|
||||
insert_point = -1
|
||||
new_levels: list[DimEntry] = []
|
||||
|
||||
# Add remaining (non-reordered) levels, finding insertion point for new dimensions
|
||||
for level in levels:
|
||||
if level.is_none():
|
||||
continue
|
||||
if level.is_positional():
|
||||
if insert_point == -1:
|
||||
insert_point = len(new_levels)
|
||||
new_levels.extend(flat_positional_dims)
|
||||
new_levels.append(level)
|
||||
|
||||
# If no positional dimensions found, append new dims at the end
|
||||
if insert_point == -1:
|
||||
insert_point = len(new_levels)
|
||||
new_levels.extend(flat_positional_dims)
|
||||
|
||||
# Match tensor to new level structure
|
||||
assert data is not None, "Cannot reorder None tensor"
|
||||
ndata = _match_levels(data, orig_levels, new_levels)
|
||||
|
||||
# Handle dimension flattening if requested
|
||||
if to_flatten:
|
||||
# Now build the reshape target
|
||||
view_shape = []
|
||||
sizes = ndata.size()
|
||||
|
||||
# Add dimensions before the reordered ones
|
||||
for i in range(insert_point):
|
||||
view_shape.append(sizes[i])
|
||||
|
||||
# Process flattening groups
|
||||
i = 0
|
||||
for start_idx, length in to_flatten:
|
||||
# Add individual dims before this flattening group
|
||||
while i < start_idx:
|
||||
view_shape.append(sizes[insert_point + i])
|
||||
i += 1
|
||||
|
||||
# Flatten the group
|
||||
new_size = 1
|
||||
for j in range(length):
|
||||
new_size *= sizes[insert_point + i + j]
|
||||
view_shape.append(new_size)
|
||||
i += length
|
||||
|
||||
# Add remaining individual dims
|
||||
while i < len(flat_positional_dims):
|
||||
view_shape.append(sizes[insert_point + i])
|
||||
i += 1
|
||||
|
||||
# Add dimensions after the reordered ones
|
||||
for i in range(insert_point + len(flat_positional_dims), len(levels)):
|
||||
view_shape.append(sizes[i])
|
||||
|
||||
# Update levels by removing flattened dimensions
|
||||
n_to_remove = len(flat_positional_dims) - n_new_positional
|
||||
if n_to_remove > 0:
|
||||
# Remove flattened levels
|
||||
new_levels = (
|
||||
new_levels[:insert_point] + new_levels[insert_point + n_to_remove :]
|
||||
)
|
||||
|
||||
ndata = ndata.reshape(view_shape)
|
||||
|
||||
# Renumber positional dimensions (negative indexing from the right)
|
||||
seen = 0
|
||||
for i in range(len(new_levels) - 1, -1, -1):
|
||||
if new_levels[i].is_positional() or (
|
||||
i >= insert_point and i < insert_point + n_new_positional
|
||||
):
|
||||
seen -= 1
|
||||
new_levels[i] = DimEntry(seen)
|
||||
|
||||
result = Tensor.from_positional(ndata, new_levels, has_device)
|
||||
return result # type: ignore[return-value]
|
67
functorch/dim/_py_inst_decoder.py
Normal file
67
functorch/dim/_py_inst_decoder.py
Normal file
@ -0,0 +1,67 @@
|
||||
import dis
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
class _PyInstDecoder:
|
||||
"""
|
||||
Decodes Python bytecode instructions to extract variable names
|
||||
"""
|
||||
|
||||
def __init__(self, code_object: Any, lasti: int) -> None:
|
||||
self.code_object = code_object
|
||||
self.instructions = list(dis.get_instructions(code_object))
|
||||
self.offset = self._find_instruction_index(lasti)
|
||||
|
||||
def _find_instruction_index(self, lasti: int) -> int:
|
||||
"""Find instruction index corresponding to lasti (byte offset)."""
|
||||
# Find the instruction at or before lasti
|
||||
# This should find the CALL instruction, not the next one
|
||||
best_idx = 0
|
||||
for i, instr in enumerate(self.instructions):
|
||||
if instr.offset <= lasti:
|
||||
best_idx = i
|
||||
else:
|
||||
break
|
||||
return best_idx
|
||||
|
||||
def next(self) -> None:
|
||||
"""Advance to the next instruction."""
|
||||
self.offset += 1
|
||||
|
||||
def opcode(self) -> Optional[str]:
|
||||
"""Get the opcode name of the current instruction."""
|
||||
if self.offset < len(self.instructions):
|
||||
return self.instructions[self.offset].opname
|
||||
return None
|
||||
|
||||
def oparg(self) -> int:
|
||||
"""Get the argument of the current instruction."""
|
||||
if self.offset < len(self.instructions):
|
||||
return self.instructions[self.offset].arg or 0
|
||||
return 0
|
||||
|
||||
def name(self) -> Optional[str]:
|
||||
"""
|
||||
Extract variable name from current instruction.
|
||||
"""
|
||||
opname = self.opcode()
|
||||
if not opname:
|
||||
return None
|
||||
|
||||
names = None
|
||||
if opname in ("STORE_NAME", "STORE_GLOBAL"):
|
||||
names = self.code_object.co_names
|
||||
elif opname == "STORE_FAST":
|
||||
names = self.code_object.co_varnames
|
||||
elif opname == "STORE_DEREF":
|
||||
names = self.code_object.co_cellvars
|
||||
if not names:
|
||||
names = self.code_object.co_freevars
|
||||
else:
|
||||
return None
|
||||
|
||||
arg = self.oparg()
|
||||
if names and 0 <= arg < len(names):
|
||||
return names[arg]
|
||||
|
||||
return None
|
68
functorch/dim/_tensor_info.py
Normal file
68
functorch/dim/_tensor_info.py
Normal file
@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._dim_entry import DimEntry
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInfo:
|
||||
tensor: Optional[torch.Tensor]
|
||||
levels: list[DimEntry]
|
||||
has_device: bool
|
||||
batchedtensor: Optional[torch.Tensor]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from ._dim_entry import DimEntry
|
||||
|
||||
assert all(isinstance(l, DimEntry) for l in self.levels)
|
||||
|
||||
def ndim(self) -> int:
|
||||
from ._dim_entry import ndim_of_levels
|
||||
|
||||
return ndim_of_levels(self.levels)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self.tensor is not None
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
h: Any, ensure_batched: bool = True, ensure_present: bool = True
|
||||
) -> TensorInfo:
|
||||
from . import Dim, DimEntry, Tensor
|
||||
|
||||
if Tensor.check_exact(h):
|
||||
# functorch Tensor with first-class dimensions
|
||||
return TensorInfo(
|
||||
h._get_tensor(),
|
||||
h._get_levels(),
|
||||
h._get_has_device(),
|
||||
h._get_batchtensor() if ensure_batched else None,
|
||||
)
|
||||
elif Dim.check_exact(h):
|
||||
# For Dim objects, only get range/batchtensor if needed and dimension is bound
|
||||
tensor = h._get_range() if h.is_bound else None
|
||||
batchtensor = (
|
||||
h._get_batchtensor() if ensure_batched and h.is_bound else None
|
||||
)
|
||||
return TensorInfo(
|
||||
tensor,
|
||||
[DimEntry(h)],
|
||||
False,
|
||||
batchtensor,
|
||||
)
|
||||
elif isinstance(h, torch.Tensor):
|
||||
# Plain torch tensor - create positional levels
|
||||
levels = []
|
||||
for i in range(-h.dim(), 0):
|
||||
levels.append(DimEntry(i))
|
||||
return TensorInfo(h, levels, True, h)
|
||||
else:
|
||||
if ensure_present:
|
||||
raise ValueError("expected a tensor object")
|
||||
return TensorInfo(None, [], False, None)
|
263
functorch/dim/_wrap.py
Normal file
263
functorch/dim/_wrap.py
Normal file
@ -0,0 +1,263 @@
|
||||
"""
|
||||
Python implementation of function wrapping functionality for functorch.dim.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
from ._dim_entry import DimEntry
|
||||
from ._enable_all_layers import EnableAllLayers
|
||||
from ._tensor_info import TensorInfo
|
||||
|
||||
|
||||
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Handle tensor conversion for torch function integration."""
|
||||
return tensor
|
||||
|
||||
|
||||
class WrappedOperator:
|
||||
"""
|
||||
This class wraps PyTorch operations to support first-class dimensions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
|
||||
):
|
||||
self.orig = orig
|
||||
self.wrapper_implementation = wrapper_implementation
|
||||
self.name = getattr(orig, "__name__", "")
|
||||
self.doc = getattr(orig, "__doc__", None)
|
||||
self.dim_name = dim_name
|
||||
|
||||
self.is_pointwise = False
|
||||
self.dim_offset = 0
|
||||
self.keepdim_offset = 1
|
||||
self.single_dim = False
|
||||
self.reduce = True
|
||||
|
||||
# Update docstring if we have a dim_name
|
||||
if self.doc and self.dim_name:
|
||||
self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"
|
||||
|
||||
def function(self) -> Callable:
|
||||
"""Create a wrapped function that calls our wrapper implementation."""
|
||||
|
||||
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
|
||||
return self.wrapper_implementation(self, *args, **kwargs)
|
||||
|
||||
# Copy metadata using functools.update_wrapper for just __name__ and __doc__
|
||||
functools.update_wrapper(
|
||||
wrapped_func, self.orig, assigned=("__name__",), updated=()
|
||||
)
|
||||
wrapped_func.__doc__ = self.doc
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
|
||||
"""Convert single dimension specification to DimEntry object."""
|
||||
from . import Dim
|
||||
|
||||
if isinstance(dim, Dim):
|
||||
if keepdim:
|
||||
raise ValueError("cannot preserve first-class dimensions with keepdim=True")
|
||||
return DimEntry(dim)
|
||||
elif isinstance(dim, int):
|
||||
i = dim
|
||||
while i >= 0:
|
||||
i -= ndim
|
||||
return DimEntry(i)
|
||||
else:
|
||||
return DimEntry()
|
||||
|
||||
|
||||
def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
|
||||
"""Convert dimension specification to list of DimEntry objects."""
|
||||
de = _wrap_dim(dim, ndim, keepdim)
|
||||
result = []
|
||||
if not de.is_none():
|
||||
result.append(de)
|
||||
else:
|
||||
for d in dim:
|
||||
result.append(_wrap_dim(d, ndim, keepdim))
|
||||
return result
|
||||
|
||||
|
||||
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
This is the core method that handles dimension-aware operations.
|
||||
"""
|
||||
if not args:
|
||||
raise ValueError("Expected at least one argument (self)")
|
||||
|
||||
# Get dimension argument
|
||||
dim_arg = kwargs.get(wrapper.dim_name)
|
||||
if dim_arg is None and wrapper.dim_offset < len(args):
|
||||
# Try to get dim from positional args (accounting for self at index 0)
|
||||
dim_idx = wrapper.dim_offset + 1
|
||||
if dim_idx < len(args):
|
||||
dim_arg = args[dim_idx]
|
||||
|
||||
# If no dimension argument provided, fall back to standard functorch handling
|
||||
if dim_arg is None:
|
||||
info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
|
||||
if not info:
|
||||
return wrapper.orig(*args, **kwargs)
|
||||
|
||||
with EnableAllLayers(info.levels) as guard:
|
||||
assert info.batchedtensor is not None
|
||||
guard.inplace_update_layers(info.batchedtensor, info.levels)
|
||||
new_args = list(args)
|
||||
new_args[0] = handle_from_tensor(info.batchedtensor)
|
||||
result = wrapper.orig(*new_args, **kwargs)
|
||||
return guard.from_batched(result, info.has_device)
|
||||
|
||||
# Handle dimension-aware operation
|
||||
info = TensorInfo.create(args[0])
|
||||
if not info:
|
||||
return wrapper.orig(*args, **kwargs)
|
||||
|
||||
# Check for keepdim parameter
|
||||
keepdim = False
|
||||
if wrapper.reduce:
|
||||
keepdim_arg = kwargs.get("keepdim")
|
||||
if keepdim_arg is None and wrapper.keepdim_offset < len(args):
|
||||
keepdim_idx = wrapper.keepdim_offset + 1
|
||||
if keepdim_idx < len(args):
|
||||
keepdim_arg = args[keepdim_idx]
|
||||
if keepdim_arg is not None:
|
||||
keepdim = bool(keepdim_arg)
|
||||
|
||||
# Wrap dimensions
|
||||
ndim = info.ndim()
|
||||
dims = _wrap_dims(dim_arg, ndim, keepdim)
|
||||
|
||||
# Convert dimensions to indices and validate
|
||||
dim_indices: list[int] = []
|
||||
seen = [False] * len(info.levels)
|
||||
|
||||
for d in dims:
|
||||
midx = None
|
||||
for i, level in enumerate(info.levels):
|
||||
if level == d:
|
||||
midx = i
|
||||
break
|
||||
|
||||
if midx is None:
|
||||
# Try to match by position/name more flexibly
|
||||
for i, level in enumerate(info.levels):
|
||||
if hasattr(level, "matches") and level.matches(d):
|
||||
midx = i
|
||||
break
|
||||
|
||||
if midx is None:
|
||||
level_strs = [str(level) for level in info.levels]
|
||||
raise ValueError(
|
||||
f"Tensor with dimensions {level_strs} does not contain {d}"
|
||||
)
|
||||
|
||||
seen[midx] = True
|
||||
dim_indices.append(midx)
|
||||
|
||||
# Determine new levels after reduction
|
||||
new_levels = []
|
||||
if wrapper.reduce and not keepdim:
|
||||
for i, level in enumerate(info.levels):
|
||||
if not seen[i]:
|
||||
new_levels.append(level)
|
||||
else:
|
||||
new_levels = info.levels[:]
|
||||
|
||||
# Create dimension indices for the original function
|
||||
if len(dim_indices) == 1:
|
||||
py_indices: Any = dim_indices[0]
|
||||
else:
|
||||
py_indices = tuple(dim_indices)
|
||||
|
||||
# Update arguments
|
||||
new_args = list(args)
|
||||
new_kwargs = kwargs.copy()
|
||||
assert info.tensor is not None
|
||||
new_args[0] = handle_from_tensor(info.tensor)
|
||||
|
||||
# Update dimension argument
|
||||
if wrapper.dim_name in new_kwargs:
|
||||
new_kwargs[wrapper.dim_name] = py_indices
|
||||
else:
|
||||
dim_idx = wrapper.dim_offset + 1
|
||||
if dim_idx < len(new_args):
|
||||
new_args = list(new_args)
|
||||
new_args[dim_idx] = py_indices
|
||||
|
||||
# Call original function
|
||||
result = wrapper.orig(*new_args, **new_kwargs)
|
||||
|
||||
# Wrap results
|
||||
def wrap_result(obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
from . import Tensor
|
||||
|
||||
return Tensor.from_positional(obj, new_levels, info.has_device)
|
||||
return obj
|
||||
|
||||
return tree_map(wrap_result, result)
|
||||
|
||||
|
||||
def _wrap(
|
||||
orig: Callable,
|
||||
dim_offset: Optional[int] = None,
|
||||
keepdim_offset: Optional[int] = None,
|
||||
dim_name: Optional[str] = None,
|
||||
single_dim: Optional[bool] = None,
|
||||
reduce: Optional[bool] = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Wrap a PyTorch function to support first-class dimensions.
|
||||
|
||||
Args:
|
||||
orig: Original function to wrap
|
||||
dim_offset: Offset for dimension argument (default: 0)
|
||||
keepdim_offset: Offset for keepdim argument (default: 1)
|
||||
dim_name: Name of dimension parameter (default: "dim")
|
||||
single_dim: Whether function takes single dimension (default: False)
|
||||
reduce: Whether function reduces dimensions (default: True)
|
||||
"""
|
||||
dim_name = dim_name or "dim"
|
||||
|
||||
wrapper = WrappedOperator(orig, patched_dim_method, dim_name)
|
||||
|
||||
if dim_offset is not None:
|
||||
wrapper.dim_offset = dim_offset
|
||||
if keepdim_offset is not None:
|
||||
wrapper.keepdim_offset = keepdim_offset
|
||||
if single_dim is not None:
|
||||
wrapper.single_dim = single_dim
|
||||
if reduce is not None:
|
||||
wrapper.reduce = reduce
|
||||
|
||||
return wrapper.function()
|
||||
|
||||
|
||||
def call_torch_function(
|
||||
wrapper: WrappedOperator,
|
||||
func: Callable,
|
||||
types: tuple,
|
||||
args: tuple = (),
|
||||
kwargs: Optional[dict] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Handle __torch_function__ calls for wrapped operators.
|
||||
"""
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from . import _Tensor
|
||||
|
||||
# Use the torch function mechanism from _Tensor
|
||||
return _Tensor.__torch_function__(func, types, args, kwargs)
|
@ -6,11 +6,14 @@
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
||||
def magic_trace(
|
||||
output: str = "trace.fxt", magic_trace_cache: str = "/tmp/magic-trace"
|
||||
) -> Generator[None, None, None]:
|
||||
pid = os.getpid()
|
||||
if not os.path.exists(magic_trace_cache):
|
||||
print(f"Downloading magic_trace to: {magic_trace_cache}")
|
||||
@ -26,6 +29,7 @@ def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
||||
subprocess.run(["chmod", "+x", magic_trace_cache])
|
||||
args = [magic_trace_cache, "attach", "-pid", str(pid), "-o", output]
|
||||
p = subprocess.Popen(args, stderr=subprocess.PIPE, encoding="utf-8")
|
||||
assert p.stderr is not None
|
||||
while True:
|
||||
x = p.stderr.readline()
|
||||
print(x)
|
||||
@ -36,7 +40,8 @@ def magic_trace(output="trace.fxt", magic_trace_cache="/tmp/magic-trace"):
|
||||
finally:
|
||||
p.send_signal(signal.SIGINT)
|
||||
r = p.wait()
|
||||
print(p.stderr.read())
|
||||
p.stderr.close()
|
||||
if p.stderr is not None:
|
||||
print(p.stderr.read())
|
||||
p.stderr.close()
|
||||
if r != 0:
|
||||
raise ValueError(f"magic_trace exited abnormally: {r}")
|
||||
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functorch._C import dim
|
||||
|
||||
|
||||
tree_flatten = dim.tree_flatten
|
||||
|
||||
|
||||
def tree_map(fn, tree):
|
||||
vs, unflatten = tree_flatten(tree)
|
||||
return unflatten(fn(v) for v in vs)
|
@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
from types import (
|
||||
BuiltinMethodType,
|
||||
FunctionType,
|
||||
@ -11,11 +12,8 @@ from types import (
|
||||
MethodDescriptorType,
|
||||
WrapperDescriptorType,
|
||||
)
|
||||
from typing import Any, Callable
|
||||
|
||||
from functorch._C import dim as _C
|
||||
|
||||
|
||||
_wrap_method = _C._wrap_method
|
||||
|
||||
FUNC_TYPES = (
|
||||
FunctionType,
|
||||
@ -26,14 +24,24 @@ FUNC_TYPES = (
|
||||
PROPERTY_TYPES = (GetSetDescriptorType, property)
|
||||
|
||||
|
||||
def wrap_type(to_patch, pattern, __torch_function__):
|
||||
wrap_method = _wrap_method
|
||||
def _py_wrap_method(orig: Callable, __torch_function__: Callable) -> Callable:
|
||||
def impl(*args: Any, **kwargs: Any) -> Any:
|
||||
return __torch_function__(orig, None, args, kwargs)
|
||||
|
||||
all = {}
|
||||
# Copy metadata using functools.update_wrapper for just __name__ and __doc__
|
||||
functools.update_wrapper(impl, orig, assigned=("__name__", "__doc__"), updated=())
|
||||
|
||||
return impl
|
||||
|
||||
|
||||
def wrap_type(to_patch: Any, pattern: type, __torch_function__: Callable) -> None:
|
||||
wrap_method = _py_wrap_method
|
||||
|
||||
all: dict[str, Any] = {}
|
||||
for t in reversed(pattern.mro()[:-1]): # skip object
|
||||
all.update(t.__dict__)
|
||||
|
||||
def wrap_attr(orig):
|
||||
def wrap_attr(orig: Any) -> property:
|
||||
return property(wrap_method(orig.__get__, __torch_function__))
|
||||
|
||||
for name, obj in all.items():
|
||||
|
@ -4,7 +4,7 @@ import functools
|
||||
from typing import Callable, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from functorch._C import dim as _C
|
||||
from functorch.dim import dims # noqa: F401
|
||||
|
||||
from ._parsing import (
|
||||
_ellipsis,
|
||||
@ -20,8 +20,6 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = ["rearrange"]
|
||||
|
||||
dims = _C.dims
|
||||
|
||||
|
||||
@functools.lru_cache(256)
|
||||
def _create_rearrange_callable(
|
||||
|
Reference in New Issue
Block a user