diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index f52d417d2ba2..95747181e848 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -24,10 +24,6 @@ from . import op_properties # use dict to avoid writing C++ bindings for set pointwise = dict.fromkeys(op_properties.pointwise, True) -use_c = True -if not use_c: - from . import reference - class _Tensor: # fast path around slow wrapping/unwrapping logic for simply queries used @@ -40,12 +36,8 @@ class _Tensor: def dim(self): return self.ndim - if use_c: - __torch_function__ = classmethod(_C.__torch_function__) - expand = _C._instancemethod(_C.expand) - else: - __torch_function__ = reference.__torch_function__ - expand = reference.expand + __torch_function__ = classmethod(_C.__torch_function__) + expand = _C._instancemethod(_C.expand) index = _C._instancemethod(_C.index) @@ -64,8 +56,6 @@ class Dim(_C.Dim, _Tensor): class Tensor(_Tensor, _C.Tensor): - if not use_c: - from_batched = staticmethod(_C.Tensor_from_batched) from_positional = staticmethod(_C.Tensor_from_positional) sum = _C._instancemethod(_C.Tensor_sum) @@ -75,21 +65,17 @@ def cat(tensors, dim, new_dim): return stack(tensors, n, dim).index([n, dim], new_dim) -if use_c: - _wrap = _C._wrap +_wrap = _C._wrap - def _def(name, *args, **kwargs): - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) - t__getitem__ = _C._instancemethod(_C.__getitem__) - stack = _C.stack - split = _C._instancemethod(_C.split) -else: - _wrap, _def = reference._wrap, reference._def - t__getitem__ = reference.t__getitem__ - stack = reference.stack - split = reference.split +def _def(name, *args, **kwargs): + orig = getattr(torch.Tensor, name) + setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs))) + + +t__getitem__ = _C._instancemethod(_C.__getitem__) +stack = _C.stack +split = _C._instancemethod(_C.split) # note: there is no python reference t__setitem__ = _C._instancemethod(_C.__setitem__) @@ -105,13 +91,10 @@ torch.Tensor.split = split _Tensor.split = split torch.Tensor.expand = _C._instancemethod(_C.expand) torch.Tensor.index = _C._instancemethod(_C.index) -wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__) +wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__) del _Tensor.ndim -if use_c: - _Tensor.order = _C._instancemethod(_C.order) -else: - _Tensor.order = reference.positional +_Tensor.order = _C._instancemethod(_C.order) _def("mean") _def("sum") diff --git a/functorch/dim/batch_tensor.py b/functorch/dim/batch_tensor.py deleted file mode 100644 index dae9b270896e..000000000000 --- a/functorch/dim/batch_tensor.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -from contextlib import contextmanager - -from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers - - -_enabled = False - - -@contextmanager -def _enable_layers(dims): - global _enabled - assert not _enabled - input = sorted((d._level, d.size) for d in dims if not isinstance(d, int)) - n = len(input) - try: - _vmap_add_layers(input) - _enabled = True - yield - finally: - _enabled = False - _vmap_remove_layers(n) diff --git a/functorch/dim/delayed_mul_tensor.py b/functorch/dim/delayed_mul_tensor.py deleted file mode 100644 index 3c136cfe1247..000000000000 --- a/functorch/dim/delayed_mul_tensor.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch - -from . import _Tensor, Tensor -from .reference import _dims, _enable_layers, llist, ltuple - - -class DelayedMulTensor(_Tensor): - def __init__(self, lhs, rhs): - self._lhs, self._rhs = lhs, rhs - self._data = None - self._levels_data = None - self._has_device = lhs._has_device or rhs._has_device - self._batchtensor_data = None - self._tensor_data = None - - @property - def _levels(self): - if self._levels_data is None: - levels = llist(self._lhs._levels) - for l in self._rhs._levels: - if l not in levels: - levels.append(l) - self._levels_data = ltuple(levels) - return self._levels_data - - @property - def _batchtensor(self): - if self._batchtensor_data is None: - with _enable_layers(self._levels): - print("bt multiply fallback") - self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor - return self._batchtensor_data - - @property - def _tensor(self): - if self._tensor_data is None: - self._tensor_data = Tensor.from_batched( - self._batchtensor, self._has_device - )._tensor - return self._tensor_data - - @property - def ndim(self): - return self._batchtensor.ndim - - @property - def dims(self): - return ltuple(super().dims) - - def sum(self, dim): - dims = _dims(dim, 0, False, False) - n = ord("a") - all_levels = self._levels - - def to_char(d): - return chr(n + all_levels.index(d)) - - plhs, levelslhs = self._lhs._tensor, self._lhs._levels - prhs, levelsrhs = self._rhs._tensor, self._rhs._levels - new_levels = [l for l in self._levels if l not in dims] - fmt = "".join( - [ - *(to_char(d) for d in levelslhs), - ",", - *(to_char(d) for d in levelsrhs), - "->", - *(to_char(d) for d in new_levels), - ] - ) - result_data = torch.einsum(fmt, (plhs, prhs)) - return Tensor.from_positional(result_data, new_levels, True) diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py deleted file mode 100644 index 9a4b56866484..000000000000 --- a/functorch/dim/dim.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import dis -import inspect -from dataclasses import dataclass -from typing import Union - -from . import DimList - - -_vmap_levels = [] - - -@dataclass -class LevelInfo: - level: int - alive: bool = True - - -class Dim: - def __init__(self, name: str, size: Union[None, int] = None): - self.name = name - self._size = None - self._vmap_level = None - if size is not None: - self.size = size - - def __del__(self): - if self._vmap_level is not None: - _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 - while ( - not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 - ): - _vmap_decrement_nesting() # noqa: F821 - _vmap_levels.pop() - - @property - def size(self): - assert self.is_bound - return self._size - - @size.setter - def size(self, size: int): - from . import DimensionBindError - - if self._size is None: - self._size = size - self._vmap_level = _vmap_increment_nesting(size, "same") # noqa: F821 - self._vmap_stack = len(_vmap_levels) - _vmap_levels.append(LevelInfo(self._vmap_level)) - - elif self._size != size: - raise DimensionBindError( - f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}" - ) - - @property - def is_bound(self): - return self._size is not None - - def __repr__(self): - return self.name - - -def extract_name(inst): - assert inst.opname == "STORE_FAST" or inst.opname == "STORE_NAME" - return inst.argval - - -_cache = {} - - -def dims(lists=0): - frame = inspect.currentframe() - assert frame is not None - calling_frame = frame.f_back - assert calling_frame is not None - code, lasti = calling_frame.f_code, calling_frame.f_lasti - key = (code, lasti) - if key not in _cache: - first = lasti // 2 + 1 - instructions = list(dis.get_instructions(calling_frame.f_code)) - unpack = instructions[first] - - if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME": - # just a single dim, not a list - name = unpack.argval - ctor = Dim if lists == 0 else DimList - _cache[key] = lambda: ctor(name=name) - else: - assert unpack.opname == "UNPACK_SEQUENCE" - ndims = unpack.argval - names = tuple( - extract_name(instructions[first + 1 + i]) for i in range(ndims) - ) - first_list = len(names) - lists - _cache[key] = lambda: tuple( - Dim(n) if i < first_list else DimList(name=n) - for i, n in enumerate(names) - ) - return _cache[key]() - - -def _dim_set(positional, arg): - def convert(a): - if isinstance(a, Dim): - return a - else: - assert isinstance(a, int) - return positional[a] - - if arg is None: - return positional - elif not isinstance(arg, (Dim, int)): - return tuple(convert(a) for a in arg) - else: - return (convert(arg),) diff --git a/functorch/dim/reference.py b/functorch/dim/reference.py deleted file mode 100644 index fd934011d823..000000000000 --- a/functorch/dim/reference.py +++ /dev/null @@ -1,645 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# reference python implementations for C ops -import torch -from functorch._C import dim as _C - -from . import op_properties -from .batch_tensor import _enable_layers -from .tree_map import tree_flatten, tree_map - - -DimList = _C.DimList -import operator -from functools import reduce - - -# use dict to avoid writing C++ bindings for set -pointwise = set(op_properties.pointwise) - - -def prod(x): - return reduce(operator.mul, x, 1) - - -def _wrap_dim(d, N, keepdim): - from . import Dim - - if isinstance(d, Dim): - assert not keepdim, "cannot preserve first-class dimensions with keepdim=True" - return d - elif d >= 0: - return d - N - else: - return d - - -def _dims(d, N, keepdim, single_dim): - from . import Dim - - if isinstance(d, (Dim, int)): - return ltuple((_wrap_dim(d, N, keepdim),)) - assert not single_dim, f"expected a single dimension or int but found: {d}" - return ltuple(_wrap_dim(x, N, keepdim) for x in d) - - -def _bind_dims_to_size(lhs_size, rhs, lhs_debug): - from . import DimensionMismatchError - - not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound) - if len(not_bound) == 1: - idx, d = not_bound[0] - rhs_so_far = prod(r.size for r in rhs if r.is_bound) - if lhs_size % rhs_so_far != 0: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}" - ) - new_size = lhs_size // rhs_so_far - d.size = new_size - elif len(not_bound) > 1: - rhs_s = tuple("?" if not r.is_bound else str(r.size) for r in rhs) - raise DimensionMismatchError( - f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}" - ) - else: - rhs_size = prod(r.size for r in rhs) - if lhs_size != rhs_size: - raise DimensionMismatchError( - f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}" - ) - - -def _tensor_levels(inp): - from . import _Tensor - - if isinstance(inp, _Tensor): - return inp._tensor, llist(inp._levels), inp._has_device - else: - return inp, llist(range(-inp.ndim, 0)), True - - -def _match_levels(v, from_levels, to_levels): - view = [] - permute = [] - requires_view = False - size = v.size() - for t in to_levels: - try: - idx = from_levels.index(t) - permute.append(idx) - view.append(size[idx]) - except ValueError: - view.append(1) - requires_view = True - if permute != list(range(len(permute))): - v = v.permute(*permute) - if requires_view: - v = v.view(*view) - return v - - -# make a single dimension positional but do not permute it, -# used to do multi-tensor operators where the dim being acted on -# should not physically move if possible -def _positional_no_permute(self, dim, expand_dim=False): - from . import Tensor - - ptensor, levels = self._tensor, llist(self._levels) - try: - idx = levels.index(dim) - except ValueError: - if not expand_dim: - raise - idx = 0 - ptensor = ptensor.expand(dim.size, *ptensor.size()) - levels.insert(0, 0) - idx_batched = 0 - for i in range(idx): - if isinstance(levels[i], int): - levels[i] -= 1 - idx_batched += 1 - levels[idx] = -idx_batched - 1 - return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched - - -def seq(a, b): - from . import Dim - - if isinstance(a, Dim) != isinstance(b, Dim): - return False - if isinstance(a, Dim): - return a is b - else: - return a == b - - -class isin: - __slots__ = () - - def __contains__(self, item): - for x in self: - if seq(item, x): - return True - return False - - def index(self, item): - for i, x in enumerate(self): - if seq(item, x): - return i - raise ValueError - - -class llist(isin, list): - __slots__ = () - - -class ltuple(isin, tuple): - __slots__ = () - - -empty_dict = {} - - -@classmethod -def __torch_function__(self, orig, cls, args, kwargs=empty_dict): - from . import _Tensor, Tensor, TensorLike - from .delayed_mul_tensor import DelayedMulTensor - - if orig is torch.Tensor.__mul__: - lhs, rhs = args - if ( - isinstance(lhs, _Tensor) - and isinstance(rhs, _Tensor) - and lhs.ndim == 0 - and rhs.ndim == 0 - ): - return DelayedMulTensor(lhs, rhs) - all_dims = llist() - flat_args, unflatten = tree_flatten((args, kwargs)) - device_holding_tensor = None - for f in flat_args: - if isinstance(f, _Tensor): - if f._has_device: - device_holding_tensor = f._batchtensor - for d in f.dims: - if d not in all_dims: - all_dims.append(d) - - def unwrap(t): - if isinstance(t, _Tensor): - r = t._batchtensor - if device_holding_tensor is not None and not t._has_device: - r = r.to(device=device_holding_tensor.device) - return r - return t - - if orig in pointwise: - result_levels = llist() - to_expand = [] - for i, f in enumerate(flat_args): - if isinstance(f, TensorLike): - ptensor, levels, _ = _tensor_levels(f) - if ( - isinstance(f, _Tensor) - and not f._has_device - and device_holding_tensor is not None - ): - ptensor = ptensor.to(device=device_holding_tensor.device) - flat_args[i] = ptensor - for l in levels: - if l not in result_levels: - result_levels.append(l) - to_expand.append((i, levels)) - - for i, levels in to_expand: - flat_args[i] = _match_levels(flat_args[i], levels, result_levels) - args, kwargs = unflatten(flat_args) - result = orig(*args, **kwargs) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional( - t, result_levels, device_holding_tensor is not None - ) - return t - - return tree_map(wrap, result) - else: - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_batched(t, device_holding_tensor is not None) - return t - - with _enable_layers(all_dims): - print(f"batch_tensor for {orig}") - args, kwargs = unflatten(unwrap(f) for f in flat_args) - result = orig(*args, **kwargs) - # print("END", orig) - return tree_map(wrap, result) - - -def positional(self, *dims): - from . import Dim, DimensionBindError, Tensor - - ptensor, levels = self._tensor, llist(self._levels) - flat_dims = llist() - view = [] - needs_view = False - ndim = self.ndim - for d in dims: - if isinstance(d, DimList): - flat_dims.extend(d) - view.extend(e.size for e in d) - elif isinstance(d, Dim): - flat_dims.append(d) - view.append(d.size) - elif isinstance(d, int): - d = _wrap_dim(d, ndim, False) - flat_dims.append(d) - view.append(ptensor.size(d)) - else: - flat_dims.extend(d) - view.append(prod(e.size for e in d)) - needs_view = True - - permute = list(range(len(levels))) - for i, d in enumerate(flat_dims): - try: - idx = levels.index(d) - except ValueError as e: - raise DimensionBindError( - f"tensor of dimensions {self.dims} does not contain dim {d}" - ) from e - p = permute[idx] - del levels[idx] - del permute[idx] - levels.insert(i, 0) - permute.insert(i, p) - ptensor = ptensor.permute(*permute) - seen = 0 - for i in range(len(levels) - 1, -1, -1): - if isinstance(levels[i], int): - seen += 1 - levels[i] = -seen - result = Tensor.from_positional(ptensor, levels, self._has_device) - if needs_view: - result = result.reshape(*view, *result.size()[len(flat_dims) :]) - return result - - -def _contains_dim(input): - from . import Dim - - for i in input: - if isinstance(i, Dim): - return True - - -def expand(self, *sizes): - if not _contains_dim(sizes): - return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes)) - dims = sizes - sizes = [d.size for d in dims] + [-1] * self.ndim - self = self.expand(*sizes) - return self[dims] - - -_not_present = object() - - -def _getarg(name, offset, args, kwargs, default): - if len(args) > offset: - return args[offset] - return kwargs.get(name, default) - - -def _patcharg(name, offset, args, kwargs, value): - if len(args) > offset: - args[offset] = value - else: - kwargs[name] = value - - -def _wrap( - orig, dim_offset=0, keepdim_offset=1, dim_name="dim", single_dim=False, reduce=True -): - from . import Dim, Tensor, TensorLike - - def fn(self, *args, **kwargs): - dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present) - if dim is _not_present or (single_dim and not isinstance(dim, Dim)): - with _enable_layers(self.dims): - print(f"dim fallback batch_tensor for {orig}") - return Tensor.from_batched( - orig(self._batchtensor, *args, **kwargs), self._has_device - ) - keepdim = ( - _getarg("keepdim", keepdim_offset, args, kwargs, False) if reduce else False - ) - t, levels = self._tensor, llist(self._levels) - dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim) - dim_indices = tuple(levels.index(d) for d in dims) - if reduce and not keepdim: - new_levels = [l for i, l in enumerate(levels) if i not in dim_indices] - else: - new_levels = levels - - if len(dim_indices) == 1: - dim_indices = dim_indices[ - 0 - ] # so that dims that really only take a single argument work... - args = list(args) - _patcharg(dim_name, dim_offset, args, kwargs, dim_indices) - - def wrap(t): - if isinstance(t, TensorLike): - return Tensor.from_positional(t, new_levels, self._has_device) - return t - - with _enable_layers(new_levels): - print(f"dim used batch_tensor for {orig}") - r = orig(t, *args, **kwargs) - return tree_map(wrap, r) - - return fn - - -def _def(name, *args, **kwargs): - from . import _Tensor - - orig = getattr(torch.Tensor, name) - setattr(_Tensor, name, _wrap(orig, *args, **kwargs)) - - -no_slice = slice(None) - -_orig_getitem = torch.Tensor.__getitem__ - - -class dim_tracker: - def __init__(self) -> None: - self.dims = llist() - self.count = [] - - def record(self, d): - if d not in self.dims: - self.dims.append(d) - self.count.append(1) - - def __getitem__(self, d): - return self.count[self.dims.index(d)] - - -def t__getitem__(self, input): - from . import _Tensor, Dim, DimensionBindError, DimList, Tensor, TensorLike - - # * bail to original example if we have a single non-Dim tensor, or a non-tensor - # * locate ... or an unbound tensor list, and determine its size, bind dim list - # (remember that None does not count to the total dim count) - # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim, - # produce the re-view if needed - # * for each single-use dim index, replace with no_slice and mark that it will be added - # (keep track of whether we have to call super) - # * call super if needed - # * if we have dims to bind, bind them (it will help if we eliminated ... and None before) - # this handles bool indexing handling, as well as some other simple cases. - - is_simple = ( - not isinstance(input, Dim) - and not isinstance(input, (tuple, list)) - and - # WAR for functorch bug where zero time tensors in getitem are not handled correctly. - not (isinstance(input, TensorLike) and input.ndim == 0) - ) - - if is_simple: - if isinstance(self, _Tensor): - return _Tensor.__torch_function__(_orig_getitem, None, (self, input)) - else: - return _orig_getitem(self, input) - - # can further optimize this case - if not isinstance(input, tuple): - input = [input] - else: - input = list(input) - - dims_indexed = 0 - expanding_object = None - dimlists = [] - for i, s in enumerate(input): - if s is ... or isinstance(s, DimList) and not s.is_bound: - if expanding_object is not None: - msg = ( - "at most one ... or unbound dimension list can exist in indexing list but" - f" found 2 at offsets {i} and {expanding_object}" - ) - raise DimensionBindError(msg) - expanding_object = i - - if isinstance(s, DimList): - dims_indexed += len(s) if s.is_bound else 0 - dimlists.append(i) - elif s is not None and s is not ...: - dims_indexed += 1 - - ndim = self.ndim - if dims_indexed > ndim: - raise IndexError( - f"at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions." - ) - if expanding_object is not None: - expanding_ndims = ndim - dims_indexed - obj = input[expanding_object] - if obj is ...: - input[expanding_object : expanding_object + 1] = [ - no_slice - ] * expanding_ndims - else: - obj.bind_len(expanding_ndims) - # flatten the dimslists into the indexing - for i in reversed(dimlists): - input[i : i + 1] = input[i] - dims_indexed = 0 - requires_view = False - size = self.size() - view_sizes = [] - dims_seen = dim_tracker() - - def add_dims(t): - if not isinstance(t, _Tensor): - return - for d in t.dims: - dims_seen.record(d) - - add_dims(self) - dim_packs = [] - for i, idx in enumerate(input): - if idx is None: - input[i] = no_slice - view_sizes.append(1) - requires_view = True - else: - sz = size[dims_indexed] - if isinstance(idx, Dim): - idx.size = sz - dims_seen.record(idx) - view_sizes.append(sz) - elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim): - for d in idx: - dims_seen.record(idx) - _bind_dims_to_size(sz, idx, f"offset {i}") - view_sizes.extend(d.size for d in idx) - requires_view = True - dim_packs.append(i) - else: - add_dims(idx) - view_sizes.append(sz) - dims_indexed += 1 - if requires_view: - self = self.view(*view_sizes) - for i in reversed(dim_packs): - input[i : i + 1] = input[i] - - # currently: - # input is flat, containing either Dim, or Tensor, or something valid for standard indexing - # self may have first-class dims as well. - - # to index: - # drop the first class dims from self, they just become direct indices of their positions - - # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index. - # these dimensions will appear and need to be bound at the first place tensor occurs - - if isinstance(self, _Tensor): - ptensor_self, levels = self._tensor, list(self._levels) - # indices to ptensor rather than self which has first-class dimensions - input_it = iter(input) - flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels] - has_device = self._has_device - to_pad = 0 - else: - ptensor_self, flat_inputs = self, input - to_pad = ptensor_self.ndim - len(flat_inputs) - has_device = True - - result_levels = [] - index_levels = [] - tensor_insert_point = None - to_expand = {} - requires_getindex = False - for i, inp in enumerate(flat_inputs): - if isinstance(inp, Dim) and dims_seen[inp] == 1: - flat_inputs[i] = no_slice - result_levels.append(inp) - elif isinstance(inp, TensorLike): - requires_getindex = True - if tensor_insert_point is None: - tensor_insert_point = len(result_levels) - ptensor, levels, _ = _tensor_levels(inp) - to_expand[i] = levels - flat_inputs[i] = ptensor - for l in levels: - if l not in index_levels: - index_levels.append(l) - else: - requires_getindex = True - result_levels.append(0) - - if tensor_insert_point is not None: - result_levels[tensor_insert_point:tensor_insert_point] = index_levels - - for i, levels in to_expand.items(): - flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels) - - if requires_getindex: - result = _orig_getitem(ptensor_self, flat_inputs) - else: - result = ptensor_self - - next_positional = -1 - if to_pad > 0: - result_levels.extend([0] * to_pad) - for i, r in enumerate(reversed(result_levels)): - if isinstance(r, int): - result_levels[-1 - i] = next_positional - next_positional -= 1 - - return Tensor.from_positional(result, result_levels, has_device) - - -# XXX - dim is optional and can be the outer-most dimension... -def stack(tensors, new_dim, dim=0, out=None): - if isinstance(dim, int): - return torch.stack(tensors, dim, out).index(dim, new_dim) - index = None - if out is not None: - out, index = _positional_no_permute(out, dim, expand_dim=True) - ptensors = [] - for t in tensors: - pt, pi = _positional_no_permute(t, dim, expand_dim=True) - if index is not None and pi != index: - pt = pt.move_dim(pi, index) - else: - index = pi - ptensors.append(pt) - pr = torch.stack(ptensors, index, out=out) - return pr.index((index, index + 1), (new_dim, dim)) - - -_orig_split = torch.Tensor.split - - -def split(self, split_size_or_sections, dim=0): - from . import _Tensor, Dim - - if isinstance(split_size_or_sections, int) or any( - isinstance(t, int) for t in split_size_or_sections - ): - if isinstance(dim, Dim): - raise ValueError( - "when dim is specified as a Dim object, split sizes must also be dimensions." - ) - return _orig_split(self, split_size_or_sections, dim=dim) - - if isinstance(dim, Dim): - assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}" - self, dim = _positional_no_permute(self, dim) - - size = self.size(dim) - total_bound_size = 0 - unbound = [] - sizes = [] - for i, d in enumerate(split_size_or_sections): - if d.is_bound: - sizes.append(d.size) - total_bound_size += d.size - else: - sizes.append(0) - unbound.append(i) - - if unbound: - assert total_bound_size <= size, ( - f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - remaining_size = size - total_bound_size - chunk_size = -(-remaining_size // len(unbound)) - for u in unbound: - sz = min(chunk_size, remaining_size) - split_size_or_sections[u].size = sz - sizes[u] = sz - remaining_size -= sz - else: - assert total_bound_size == size, ( - f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})" - ) - return tuple( - t.index(dim, d) - for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)) - ) diff --git a/functorch/dim/wrap_type.py b/functorch/dim/wrap_type.py index aae543b91a89..b9ebda47c4cf 100644 --- a/functorch/dim/wrap_type.py +++ b/functorch/dim/wrap_type.py @@ -26,18 +26,8 @@ FUNC_TYPES = ( PROPERTY_TYPES = (GetSetDescriptorType, property) -def _py_wrap_method(orig, __torch_function__): - def impl(*args, **kwargs): - return __torch_function__(orig, None, args, kwargs) - - return impl - - -def wrap_type(use_c, to_patch, pattern, __torch_function__): - if use_c: - wrap_method = _wrap_method - else: - wrap_method = _py_wrap_method +def wrap_type(to_patch, pattern, __torch_function__): + wrap_method = _wrap_method all = {} for t in reversed(pattern.mro()[:-1]): # skip object