mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is a follow-up to #165214 that continues applying the ruff UP035 rule to the code base. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165515 Approved by: https://github.com/Lucaskabela
268 lines
8.1 KiB
Python
268 lines
8.1 KiB
Python
"""
Python implementation of function wrapping functionality for functorch.dim.

Defines ``WrappedOperator`` and ``_wrap``. Together they wrap PyTorch functions
so that the dimension argument may be either an integer or a torchdim ``Dim``
object. The module also contains the helpers that translate such arguments into
indices over a tensor's levels before the original function is called.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import functools
|
|
from typing import Any, Optional, TYPE_CHECKING
|
|
|
|
import torch
|
|
from torch.utils._pytree import tree_map
|
|
|
|
from ._dim_entry import DimEntry
|
|
from ._enable_all_layers import EnableAllLayers
|
|
from ._tensor_info import TensorInfo
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Callable
|
|
|
|
|
|
def handle_from_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Prepare ``tensor`` before it is passed to a wrapped torch function.

    This is currently a pass-through: the very same object is returned.
    """
    converted = tensor  # no conversion is required at present
    return converted
|
|
|
|
|
|
class WrappedOperator:
    """
    Adapter that lets a PyTorch callable accept first-class dimensions.

    An instance stores the original callable, the routine that performs the
    dimension-aware call (``wrapper_implementation``), and the per-operator
    configuration that routine reads (``dim_name``, ``dim_offset``,
    ``keepdim_offset``, ``single_dim``, ``reduce``, ``is_pointwise``).
    ``function()`` builds the plain callable that is exposed to users.
    """

    def __init__(
        self, orig: Callable, wrapper_implementation: Callable, dim_name: str = "dim"
    ):
        # The wrapped callable and the routine that performs the dim-aware call.
        self.orig = orig
        self.wrapper_implementation = wrapper_implementation

        # Metadata borrowed from ``orig``; either attribute may be missing.
        self.name = getattr(orig, "__name__", "")
        self.doc = getattr(orig, "__doc__", None)
        self.dim_name = dim_name

        # Default configuration; callers (e.g. ``_wrap``) may override these.
        self.is_pointwise = False
        self.dim_offset = 0
        self.keepdim_offset = 1
        self.single_dim = False
        self.reduce = True

        # Only advertise Dim support when there is a docstring to extend and a
        # named dimension argument to mention in it.
        has_doc_to_extend = bool(self.doc) and bool(self.dim_name)
        if has_doc_to_extend:
            self.doc = f"{self.doc}\nArgument '{self.dim_name}' can be either an integer or a torchdim.Dim object.\n"

    def function(self) -> Callable:
        """Return the user-facing callable for this operator.

        Calling the result with ``(*args, **kwargs)`` invokes
        ``self.wrapper_implementation(self, *args, **kwargs)``. The attribute is
        looked up on every call, so later reassignments take effect. The callable
        takes ``__name__`` from ``orig`` when present, and ``update_wrapper``
        records ``orig`` as ``__wrapped__``. Its ``__doc__`` is ``self.doc``.
        """
        operator = self

        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return operator.wrapper_implementation(operator, *args, **kwargs)

        # Copy only ``__name__`` (and nothing into ``__dict__``); the docstring
        # is assigned explicitly because ``self.doc`` may have been extended.
        wrapped = functools.update_wrapper(
            wrapped_func, self.orig, assigned=("__name__",), updated=()
        )
        wrapped.__doc__ = self.doc
        return wrapped
|
|
|
|
|
|
def _wrap_dim(dim: Any, ndim: int, keepdim: bool = False) -> DimEntry:
|
|
"""Convert single dimension specification to DimEntry object."""
|
|
from . import Dim
|
|
|
|
if isinstance(dim, Dim):
|
|
if keepdim:
|
|
raise ValueError("cannot preserve first-class dimensions with keepdim=True")
|
|
return DimEntry(dim)
|
|
elif isinstance(dim, int):
|
|
i = dim
|
|
while i >= 0:
|
|
i -= ndim
|
|
return DimEntry(i)
|
|
else:
|
|
return DimEntry()
|
|
|
|
|
|
def _wrap_dims(dim: Any, ndim: int, keepdim: bool = False) -> list[DimEntry]:
    """Normalize a dimension argument into a list of ``DimEntry`` objects.

    If ``_wrap_dim`` yields a non-none entry for ``dim`` as a whole, that entry
    is returned in a one-element list. Otherwise ``dim`` is iterated, and each
    element is passed through ``_wrap_dim`` separately.
    """
    single = _wrap_dim(dim, ndim, keepdim)
    if single.is_none():
        # Not a lone int/Dim: treat ``dim`` as a sequence of dimensions.
        return [_wrap_dim(item, ndim, keepdim) for item in dim]
    return [single]
|
|
|
|
|
|
def patched_dim_method(wrapper: WrappedOperator, *args: Any, **kwargs: Any) -> Any:
|
|
"""
|
|
This is the core method that handles dimension-aware operations.
|
|
"""
|
|
if not args:
|
|
raise ValueError("Expected at least one argument (self)")
|
|
|
|
# Get dimension argument
|
|
dim_arg = kwargs.get(wrapper.dim_name)
|
|
if dim_arg is None and wrapper.dim_offset < len(args):
|
|
# Try to get dim from positional args (accounting for self at index 0)
|
|
dim_idx = wrapper.dim_offset + 1
|
|
if dim_idx < len(args):
|
|
dim_arg = args[dim_idx]
|
|
|
|
# If no dimension argument provided, fall back to standard functorch handling
|
|
if dim_arg is None:
|
|
info = TensorInfo.create(args[0], ensure_batched=True, ensure_present=False)
|
|
if not info:
|
|
return wrapper.orig(*args, **kwargs)
|
|
|
|
with EnableAllLayers(info.levels) as guard:
|
|
assert info.batchedtensor is not None
|
|
guard.inplace_update_layers(info.batchedtensor, info.levels)
|
|
new_args = list(args)
|
|
new_args[0] = handle_from_tensor(info.batchedtensor)
|
|
result = wrapper.orig(*new_args, **kwargs)
|
|
return guard.from_batched(result, info.has_device)
|
|
|
|
# Handle dimension-aware operation
|
|
info = TensorInfo.create(args[0])
|
|
if not info:
|
|
return wrapper.orig(*args, **kwargs)
|
|
|
|
# Check for keepdim parameter
|
|
keepdim = False
|
|
if wrapper.reduce:
|
|
keepdim_arg = kwargs.get("keepdim")
|
|
if keepdim_arg is None and wrapper.keepdim_offset < len(args):
|
|
keepdim_idx = wrapper.keepdim_offset + 1
|
|
if keepdim_idx < len(args):
|
|
keepdim_arg = args[keepdim_idx]
|
|
if keepdim_arg is not None:
|
|
keepdim = bool(keepdim_arg)
|
|
|
|
# Wrap dimensions
|
|
ndim = info.ndim()
|
|
dims = _wrap_dims(dim_arg, ndim, keepdim)
|
|
|
|
# Convert dimensions to indices and validate
|
|
dim_indices: list[int] = []
|
|
seen = [False] * len(info.levels)
|
|
|
|
for d in dims:
|
|
midx = None
|
|
for i, level in enumerate(info.levels):
|
|
if level == d:
|
|
midx = i
|
|
break
|
|
|
|
if midx is None:
|
|
# Try to match by position/name more flexibly
|
|
for i, level in enumerate(info.levels):
|
|
if hasattr(level, "matches") and level.matches(d):
|
|
midx = i
|
|
break
|
|
|
|
if midx is None:
|
|
level_strs = [str(level) for level in info.levels]
|
|
raise ValueError(
|
|
f"Tensor with dimensions {level_strs} does not contain {d}"
|
|
)
|
|
|
|
seen[midx] = True
|
|
dim_indices.append(midx)
|
|
|
|
# Determine new levels after reduction
|
|
new_levels = []
|
|
if wrapper.reduce and not keepdim:
|
|
for i, level in enumerate(info.levels):
|
|
if not seen[i]:
|
|
new_levels.append(level)
|
|
else:
|
|
new_levels = info.levels[:]
|
|
|
|
# Create dimension indices for the original function
|
|
if len(dim_indices) == 1:
|
|
py_indices: Any = dim_indices[0]
|
|
else:
|
|
py_indices = tuple(dim_indices)
|
|
|
|
# Update arguments
|
|
new_args = list(args)
|
|
new_kwargs = kwargs.copy()
|
|
assert info.tensor is not None
|
|
new_args[0] = handle_from_tensor(info.tensor)
|
|
|
|
# Update dimension argument
|
|
if wrapper.dim_name in new_kwargs:
|
|
new_kwargs[wrapper.dim_name] = py_indices
|
|
else:
|
|
dim_idx = wrapper.dim_offset + 1
|
|
if dim_idx < len(new_args):
|
|
new_args = list(new_args)
|
|
new_args[dim_idx] = py_indices
|
|
|
|
# Call original function
|
|
result = wrapper.orig(*new_args, **new_kwargs)
|
|
|
|
# Wrap results
|
|
def wrap_result(obj: Any) -> Any:
|
|
if isinstance(obj, torch.Tensor):
|
|
from . import Tensor
|
|
|
|
return Tensor.from_positional(obj, new_levels, info.has_device)
|
|
return obj
|
|
|
|
return tree_map(wrap_result, result)
|
|
|
|
|
|
def _wrap(
    orig: Callable,
    dim_offset: Optional[int] = None,
    keepdim_offset: Optional[int] = None,
    dim_name: Optional[str] = None,
    single_dim: Optional[bool] = None,
    reduce: Optional[bool] = None,
) -> Callable:
    """
    Wrap a PyTorch function so it accepts first-class dimensions.

    Any optional argument left as ``None`` keeps the ``WrappedOperator``
    default; any other value overrides it.

    Args:
        orig: Function to wrap.
        dim_offset: Offset of the dimension argument, counted after ``self``
            (default: 0).
        keepdim_offset: Offset of the ``keepdim`` argument, counted after
            ``self`` (default: 1).
        dim_name: Keyword name of the dimension argument; any falsy value
            means "dim".
        single_dim: Whether the function takes a single dimension
            (default: False).
        reduce: Whether the function reduces the given dimensions
            (default: True).

    Returns:
        The callable produced by ``WrappedOperator.function()``.
    """
    wrapper = WrappedOperator(orig, patched_dim_method, dim_name or "dim")

    overrides = {
        "dim_offset": dim_offset,
        "keepdim_offset": keepdim_offset,
        "single_dim": single_dim,
        "reduce": reduce,
    }
    for attr, value in overrides.items():
        if value is not None:
            setattr(wrapper, attr, value)

    return wrapper.function()
|
|
|
|
|
|
def call_torch_function(
    wrapper: WrappedOperator,
    func: Callable,
    types: tuple,
    args: tuple = (),
    kwargs: Optional[dict] = None,
) -> Any:
    """
    Dispatch a ``__torch_function__`` call made on behalf of a wrapped operator.

    The call is forwarded to ``_Tensor.__torch_function__``. A missing
    ``kwargs`` is replaced by an empty dict. ``wrapper`` is not consulted here.
    """
    resolved_kwargs = {} if kwargs is None else kwargs

    # Deferred import to avoid a circular import at module load time.
    from . import _Tensor

    return _Tensor.__torch_function__(func, types, args, resolved_kwargs)
|