pytorch/functorch/dim/_wrap.py

Yuanyuan Chen b2953f5643 [9/N] Apply ruff UP035 rule (#165515)
This is a follow-up to #165214, continuing to apply the ruff UP035 rule to the code base.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165515
Approved by: https://github.com/Lucaskabela
2025-10-17 00:09:51 +00:00

"""
Python implementation of function wrapping functionality for functorch.dim.
"""
from __future__ import annotations
import functools
from typing import Any, Optional, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map
from ._dim_entry import DimEntry
from ._enable_all_layers import EnableAllLayers
from ._tensor_info import TensorInfo
if TYPE_CHECKING:
from collections.abc import Callable


def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Handle tensor conversion for torch function integration."""
    return tensor
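

# Note: handle_from_tensor is currently an identity passthrough. The name
# appears to mirror the handle-unwrapping step of the original C++
# implementation of functorch.dim (an assumption based on the module
# docstring), kept so call sites read the same as their C++ counterparts.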


class WrappedOperator:
    """
    This class wraps PyTorch operations to support first-class dimensions.
    """

    def __init__(
        self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
    ):
        self.orig = orig
        self.wrapper_implementation = wrapper_implementation
        self.name = getattr(orig, "__name__", "")
        self.doc = getattr(orig, "__doc__", None)
        self.dim_name = dim_name

        # Default behavior flags; _wrap() overrides these per operation.
        self.is_pointwise = False  # operation applies elementwise
        self.dim_offset = 0  # positional offset of the `dim` argument (after self)
        self.keepdim_offset = 1  # positional offset of the `keepdim` argument
        self.single_dim = False  # operation accepts a single dim, not a list of dims
        self.reduce = True  # operation reduces over the given dimensions

        # Update docstring if we have a dim_name
        if self.doc and self.dim_name:
            self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"

    def function(self) -> Callable:
        """Create a wrapped function that calls our wrapper implementation."""

        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return self.wrapper_implementation(self, *args, **kwargs)

        # Copy __name__ from the original via functools.update_wrapper, then
        # attach the (possibly augmented) docstring built in __init__.
        functools.update_wrapper(
            wrapped_func, self.orig, assigned=("__name__",), updated=()
        )
        wrapped_func.__doc__ = self.doc
        return wrapped_func
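

# Illustrative sketch (hypothetical usage, not part of this module): building
# a wrapper by hand rather than through _wrap() below.
#
#   op = WrappedOperator(torch.sum, patched_dim_method)
#   wrapped_sum = op.function()
#   # wrapped_sum(t, dim) now routes through patched_dim_method(op, t, dim)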


def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
    """Convert a single dimension specification to a DimEntry object."""
    from . import Dim

    if isinstance(dim, Dim):
        if keepdim:
            raise ValueError("cannot preserve first-class dimensions with keepdim=True")
        return DimEntry(dim)
    elif isinstance(dim, int):
        # Normalize a non-negative index to its negative (from-the-end) form,
        # which is how positional levels are stored.
        i = dim
        while i >= 0:
            i -= ndim
        return DimEntry(i)
    else:
        return DimEntry()
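

# Worked example (a sketch of the conversions above): with ndim=3,
# _wrap_dim(1, 3) yields DimEntry(-2) -- positive index 1 counted from the
# end of a 3-d tensor; a first-class Dim passes through as DimEntry(dim);
# anything else (e.g. a tuple of dims) yields the empty DimEntry().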


def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
    """Convert a dimension specification to a list of DimEntry objects."""
    de = _wrap_dim(dim, ndim, keepdim)
    result = []
    if not de.is_none():
        # A single int or Dim was given.
        result.append(de)
    else:
        # Otherwise `dim` is assumed to be an iterable of dims.
        for d in dim:
            result.append(_wrap_dim(d, ndim, keepdim))
    return result
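

# For example (a sketch): _wrap_dims((0, 2), ndim=3) returns
# [DimEntry(-3), DimEntry(-1)], while _wrap_dims(0, ndim=3) returns
# [DimEntry(-3)].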


def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
    """
    This is the core method that handles dimension-aware operations.
    """
    if not args:
        raise ValueError("Expected at least one argument (self)")

    # Get dimension argument
    dim_arg = kwargs.get(wrapper.dim_name)
    if dim_arg is None and wrapper.dim_offset < len(args):
        # Try to get dim from positional args (accounting for self at index 0)
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(args):
            dim_arg = args[dim_idx]

    # If no dimension argument provided, fall back to standard functorch handling
    if dim_arg is None:
        info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
        if not info:
            return wrapper.orig(*args, **kwargs)

        with EnableAllLayers(info.levels) as guard:
            assert info.batchedtensor is not None
            guard.inplace_update_layers(info.batchedtensor, info.levels)
            new_args = list(args)
            new_args[0] = handle_from_tensor(info.batchedtensor)
            result = wrapper.orig(*new_args, **kwargs)
            return guard.from_batched(result, info.has_device)
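
    # Past this point a dim argument was given, so instead of re-entering
    # functorch's batching layers we translate first-class dims into plain
    # positional indices for the original op.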
    # Handle dimension-aware operation
    info = TensorInfo.create(args[0])
    if not info:
        return wrapper.orig(*args, **kwargs)

    # Check for keepdim parameter
    keepdim = False
    if wrapper.reduce:
        keepdim_arg = kwargs.get("keepdim")
        if keepdim_arg is None and wrapper.keepdim_offset < len(args):
            keepdim_idx = wrapper.keepdim_offset + 1
            if keepdim_idx < len(args):
                keepdim_arg = args[keepdim_idx]
        if keepdim_arg is not None:
            keepdim = bool(keepdim_arg)

    # Wrap dimensions
    ndim = info.ndim()
    dims = _wrap_dims(dim_arg, ndim, keepdim)

    # Convert dimensions to indices and validate
    dim_indices: list[int] = []
    seen = [False] * len(info.levels)

    for d in dims:
        midx = None
        for i, level in enumerate(info.levels):
            if level == d:
                midx = i
                break

        if midx is None:
            # Try to match by position/name more flexibly
            for i, level in enumerate(info.levels):
                if hasattr(level, "matches") and level.matches(d):
                    midx = i
                    break

        if midx is None:
            level_strs = [str(level) for level in info.levels]
            raise ValueError(
                f"Tensor with dimensions {level_strs} does not contain {d}"
            )

        seen[midx] = True
        dim_indices.append(midx)
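
    # Example of the resolution above (a sketch with hypothetical values): if
    # info.levels is [DimEntry(batch), DimEntry(-2), DimEntry(-1)] and dims
    # resolved to [DimEntry(batch)], then dim_indices == [0] and
    # seen == [True, False, False].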

    # Determine new levels after reduction
    new_levels = []
    if wrapper.reduce and not keepdim:
        for i, level in enumerate(info.levels):
            if not seen[i]:
                new_levels.append(level)
    else:
        new_levels = info.levels[:]

    # Create dimension indices for the original function
    if len(dim_indices) == 1:
        py_indices: Any = dim_indices[0]
    else:
        py_indices = tuple(dim_indices)

    # Update arguments
    new_args = list(args)
    new_kwargs = kwargs.copy()
    assert info.tensor is not None
    new_args[0] = handle_from_tensor(info.tensor)

    # Update dimension argument
    if wrapper.dim_name in new_kwargs:
        new_kwargs[wrapper.dim_name] = py_indices
    else:
        dim_idx = wrapper.dim_offset + 1
        if dim_idx < len(new_args):
            new_args[dim_idx] = py_indices

    # Call original function
    result = wrapper.orig(*new_args, **new_kwargs)

    # Wrap results
    def wrap_result(obj: Any) -> Any:
        if isinstance(obj, torch.Tensor):
            from . import Tensor

            return Tensor.from_positional(obj, new_levels, info.has_device)
        return obj

    return tree_map(wrap_result, result)
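

# End-to-end sketch of patched_dim_method (hypothetical values): calling a
# wrapped `sum` as wrapped_sum(t, batch), where t carries levels
# [batch, -2, -1], resolves `batch` to positional index 0, calls the original
# sum on the plain tensor with dim=0, and re-wraps the result with the
# surviving levels [-2, -1].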


def _wrap(
    orig: Callable,
    dim_offset: Optional[int] = None,
    keepdim_offset: Optional[int] = None,
    dim_name: Optional[str] = None,
    single_dim: Optional[bool] = None,
    reduce: Optional[bool] = None,
) -> Callable:
    """
    Wrap a PyTorch function to support first-class dimensions.

    Args:
        orig: Original function to wrap
        dim_offset: Offset for dimension argument (default: 0)
        keepdim_offset: Offset for keepdim argument (default: 1)
        dim_name: Name of dimension parameter (default: "dim")
        single_dim: Whether function takes single dimension (default: False)
        reduce: Whether function reduces dimensions (default: True)
    """
    dim_name = dim_name or "dim"

    wrapper = WrappedOperator(orig, patched_dim_method, dim_name)

    if dim_offset is not None:
        wrapper.dim_offset = dim_offset
    if keepdim_offset is not None:
        wrapper.keepdim_offset = keepdim_offset
    if single_dim is not None:
        wrapper.single_dim = single_dim
    if reduce is not None:
        wrapper.reduce = reduce

    return wrapper.function()
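

# Usage sketch (hypothetical; the real call sites live elsewhere in
# functorch.dim):
#
#   sum_wrapped = _wrap(torch.Tensor.sum)
#   softmax_wrapped = _wrap(torch.nn.functional.softmax, reduce=False)
#   # sum_wrapped(t, batch) accepts a first-class Dim wherever an int dim
#   # is expected, reducing over that dimension.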


def call_torch_function(
    wrapper: WrappedOperator,
    func: Callable,
    types: tuple,
    args: tuple = (),
    kwargs: Optional[dict] = None,
) -> Any:
    """
    Handle __torch_function__ calls for wrapped operators.
    """
    if kwargs is None:
        kwargs = {}

    # Import here to avoid circular imports
    from . import _Tensor

    # Use the torch function mechanism from _Tensor
    return _Tensor.__torch_function__(func, types, args, kwargs)
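

# Delegation sketch (hypothetical usage): a wrapped operator participating in
# the __torch_function__ protocol would forward like so:
#
#   op = WrappedOperator(torch.softmax, patched_dim_method)
#   call_torch_function(op, torch.softmax, (torch.Tensor,), args=(t, 0))
#   # equivalent to _Tensor.__torch_function__(torch.softmax, types, args, {})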