mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes #141563

In NumPy, an ellipsis always acts as a separator between advanced indices, even when the ellipsis doesn't actually match any dimensions. In PyTorch, an empty ellipsis doesn't cause a separation. This leads to differing behavior between NumPy and PyTorch in this edge case, which in turn causes a bug when using torch.compile:

```python
>>> import numpy as np
>>> f = lambda x: x[:,(0,1),...,(0,1)].shape
>>> a = np.ones((3, 4, 5))
>>> f(a)
(2, 3)
>>> torch.compile(f)(a)
(3, 2)
```

Similarly to #157676, this PR doesn't change PyTorch's behavior, but it fixes the translation layer, ensuring torch._numpy compatibility with NumPy. I am marking this PR as fixing #141563, even though PyTorch behavior isn't modified.

Note that there are still some other bugs in PyTorch's advanced indexing that need to be fixed (mainly regarding proper accounting of dimensions when multidimensional boolean masks are present), but those need to be fixed at the ATen operator level. Examples:
- #71673
- #107699
- #158125

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158297
Approved by: https://github.com/soumith
721 lines
21 KiB
Python
721 lines
21 KiB
Python
# mypy: ignore-errors
|
|
|
|
from __future__ import annotations
|
|
|
|
import builtins
|
|
import math
|
|
import operator
|
|
from collections.abc import Sequence
|
|
|
|
import torch
|
|
|
|
from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
|
|
from ._normalizations import (
|
|
ArrayLike,
|
|
normalize_array_like,
|
|
normalizer,
|
|
NotImplementedType,
|
|
)
|
|
|
|
|
|
# Module-level alias bound to ``None``.
newaxis = None

# Every flag name that ``Flags`` recognises.
FLAGS = [
    "C_CONTIGUOUS",
    "F_CONTIGUOUS",
    "OWNDATA",
    "WRITEABLE",
    "ALIGNED",
    "WRITEBACKIFCOPY",
    "FNC",
    "FORC",
    "BEHAVED",
    "CARRAY",
    "FARRAY",
]

# Abbreviations accepted by ``Flags.__getitem__`` in place of a full flag name.
SHORTHAND_TO_FLAGS = {
    "C": "C_CONTIGUOUS",
    "F": "F_CONTIGUOUS",
    "O": "OWNDATA",
    "W": "WRITEABLE",
    "A": "ALIGNED",
    "X": "WRITEBACKIFCOPY",
    "B": "BEHAVED",
    "CA": "CARRAY",
    "FA": "FARRAY",
}


class Flags:
    """Read-only lookup of array flags by full name, shorthand or attribute.

    ``flag_to_value`` maps names from ``FLAGS`` to their values. Flags that are
    listed in ``FLAGS`` but missing from the mapping raise
    ``NotImplementedError`` when read; any attempt to modify a flag raises.
    """

    def __init__(self, flag_to_value: dict):
        unknown = [name for name in flag_to_value if name not in FLAGS]
        assert not unknown  # sanity check
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Attribute access uses the lower-case spelling of a flag name,
        # e.g. ``flags.c_contiguous``.
        flag_name = attr.upper()
        if not (attr.islower() and flag_name in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[flag_name]

    def __getitem__(self, key):
        # Expand a shorthand such as "C" to its full flag name first.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A known flag for which no value was supplied.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        flag_name = attr.upper()
        if attr.islower() and flag_name in FLAGS:
            # Route flag assignments through __setitem__ (which rejects them).
            self[flag_name] = value
            return
        super().__setattr__(attr, value)

    def __setitem__(self, key, value):
        known = key in FLAGS or key in SHORTHAND_TO_FLAGS
        if not known:
            raise KeyError(f"No flag key '{key}'")
        raise NotImplementedError("Modifying flags is not implemented")
|
|
|
|
|
|
def create_method(fn, name=None):
    """Return a new function that forwards every call to ``fn``.

    The returned function is labelled as a method of ``ndarray``: its
    ``__name__`` is ``name`` (falling back to ``fn.__name__`` when ``name`` is
    falsy) and its ``__qualname__`` is ``"ndarray.<name>"``.
    """
    if not name:
        name = fn.__name__

    def forwarder(*args, **kwargs):
        return fn(*args, **kwargs)

    forwarder.__name__ = name
    forwarder.__qualname__ = f"ndarray.{name}"
    return forwarder
|
|
|
|
|
|
# Each table maps a method name on ndarray to the name of the function that
# implements it. A value of None means the function has the same name as the
# method.
methods = dict.fromkeys(
    (
        "clip",
        "nonzero",
        "repeat",
        "round",
        "squeeze",
        "swapaxes",
        "ravel",
        # linalg
        "diagonal",
        "dot",
        "trace",
        # sorting
        "argsort",
        "searchsorted",
        # reductions
        "argmax",
        "argmin",
        "any",
        "all",
        "max",
        "min",
        "ptp",
        "sum",
        "prod",
        "mean",
        "var",
        "std",
        # scans
        "cumsum",
        "cumprod",
        # advanced indexing
        "take",
        "choose",
    )
)

# Unary and comparison operators; keys are written without the surrounding
# double underscores.
dunder = {
    "abs": "absolute",
    "invert": None,
    "pos": "positive",
    "neg": "negative",
    "gt": "greater",
    "lt": "less",
    "ge": "greater_equal",
    "le": "less_equal",
}

# Binary operators that also get reflected (__r*__) and in-place (__i*__)
# variants.
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
|
|
|
|
|
|
def _upcast_int_indices(index):
|
|
if isinstance(index, torch.Tensor):
|
|
if index.dtype in (torch.int8, torch.int16, torch.int32, torch.uint8):
|
|
return index.to(torch.int64)
|
|
elif isinstance(index, tuple):
|
|
return tuple(_upcast_int_indices(i) for i in index)
|
|
return index
|
|
|
|
|
|
def _has_advanced_indexing(index):
|
|
"""Check if there's any advanced indexing"""
|
|
return any(
|
|
isinstance(idx, (Sequence, bool))
|
|
or (isinstance(idx, torch.Tensor) and (idx.dtype == torch.bool or idx.ndim > 0))
|
|
for idx in index
|
|
)
|
|
|
|
|
|
def _numpy_compatible_indexing(index):
    """Wrap integer scalars in one-element lists when advanced indexing is used.

    ``index`` is first normalised to a tuple. If it contains no advanced index
    it is returned as that tuple. Otherwise every plain ``int`` and every
    zero-dimensional tensor that is neither floating point nor boolean is
    replaced by ``[entry]``; all other entries are kept as they are.
    """
    entries = index if isinstance(index, tuple) else (index,)

    if not _has_advanced_indexing(entries):
        return entries

    def is_integer_scalar(entry):
        if isinstance(entry, torch.Tensor):
            # Zero-dimensional tensors holding integers are treated the same
            # as integer scalars.
            return (
                entry.ndim == 0
                and not torch.is_floating_point(entry)
                and entry.dtype != torch.bool
            )
        # Python bools are ints too, but must NOT be converted: True/False
        # have special meaning in NumPy indexing.
        return isinstance(entry, int) and not isinstance(entry, bool)

    return tuple([entry] if is_integer_scalar(entry) else entry for entry in entries)
|
|
|
|
|
|
def _get_bool_depth(s):
|
|
"""Returns the depth of a boolean sequence/tensor"""
|
|
if isinstance(s, bool):
|
|
return True, 0
|
|
if isinstance(s, torch.Tensor) and s.dtype == torch.bool:
|
|
return True, s.ndim
|
|
if not (isinstance(s, Sequence) and s and s[0] != s):
|
|
return False, 0
|
|
is_bool, depth = _get_bool_depth(s[0])
|
|
return is_bool, depth + 1
|
|
|
|
|
|
def _numpy_empty_ellipsis_patch(index, tensor_ndim):
|
|
"""
|
|
Patch for NumPy-compatible ellipsis behavior when ellipsis doesn't match any dimensions.
|
|
|
|
In NumPy, when an ellipsis (...) doesn't actually match any dimensions of the input array,
|
|
it still acts as a separator between advanced indices. PyTorch doesn't have this behavior.
|
|
|
|
This function detects when we have:
|
|
1. Advanced indexing on both sides of an ellipsis
|
|
2. The ellipsis doesn't actually match any dimensions
|
|
"""
|
|
if not isinstance(index, tuple):
|
|
index = (index,)
|
|
|
|
# Find ellipsis position
|
|
ellipsis_pos = None
|
|
for i, idx in enumerate(index):
|
|
if idx is Ellipsis:
|
|
ellipsis_pos = i
|
|
break
|
|
|
|
# If no ellipsis, no patch needed
|
|
if ellipsis_pos is None:
|
|
return index, lambda x: x, lambda x: x
|
|
|
|
# Count non-ellipsis dimensions consumed by the index
|
|
consumed_dims = 0
|
|
for idx in index:
|
|
is_bool, depth = _get_bool_depth(idx)
|
|
if is_bool:
|
|
consumed_dims += depth
|
|
elif idx is Ellipsis or idx is None:
|
|
continue
|
|
else:
|
|
consumed_dims += 1
|
|
|
|
# Calculate how many dimensions the ellipsis should match
|
|
ellipsis_dims = tensor_ndim - consumed_dims
|
|
|
|
# Check if ellipsis doesn't match any dimensions
|
|
if ellipsis_dims == 0:
|
|
# Check if we have advanced indexing on both sides of ellipsis
|
|
left_advanced = _has_advanced_indexing(index[:ellipsis_pos])
|
|
right_advanced = _has_advanced_indexing(index[ellipsis_pos + 1 :])
|
|
|
|
if left_advanced and right_advanced:
|
|
# This is the case where NumPy and PyTorch differ
|
|
# We need to ensure the advanced indices are treated as separated
|
|
new_index = index[:ellipsis_pos] + (None,) + index[ellipsis_pos + 1 :]
|
|
end_ndims = 1 + sum(
|
|
1 for idx in index[ellipsis_pos + 1 :] if isinstance(idx, slice)
|
|
)
|
|
|
|
def squeeze_fn(x):
|
|
return x.squeeze(-end_ndims)
|
|
|
|
def unsqueeze_fn(x):
|
|
if isinstance(x, torch.Tensor) and x.ndim >= end_ndims:
|
|
return x.unsqueeze(-end_ndims)
|
|
return x
|
|
|
|
return new_index, squeeze_fn, unsqueeze_fn
|
|
|
|
return index, lambda x: x, lambda x: x
|
|
|
|
|
|
# Used to indicate that a parameter is unspecified (as opposed to explicitly
|
|
# `None`)
|
|
class _Unspecified:
|
|
pass
|
|
|
|
|
|
_Unspecified.unspecified = _Unspecified()
|
|
|
|
###############################################################
|
|
# ndarray class #
|
|
###############################################################
|
|
|
|
|
|
class ndarray:
    """Array object that keeps its data in a ``torch.Tensor`` (``self.tensor``).

    Most methods and operators are generated in the class body from the
    module-level ``methods``, ``dunder`` and ``ri_dunder`` tables and forward
    to functions of ``_funcs`` / ``_ufuncs``.
    """

    def __init__(self, t=None):
        """Wrap tensor ``t``; with no argument, wrap an empty ``torch.Tensor()``.

        Raises:
            ValueError: if ``t`` is neither ``None`` nor a ``torch.Tensor``.
        """
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        # Reflected variant: swap the operands. ``fn=fn`` binds the current
        # function at definition time instead of the loop's last value.
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        # In-place variant: write the result back into ``self``.
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        """Shape of the wrapped tensor as a plain tuple."""
        return tuple(self.tensor.shape)

    @property
    def size(self):
        """Total number of elements."""
        return self.tensor.numel()

    @property
    def ndim(self):
        """Number of dimensions."""
        return self.tensor.ndim

    @property
    def dtype(self):
        """The tensor's dtype wrapped in ``_dtypes.dtype``."""
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        """Strides in bytes: tensor strides multiplied by the element size."""
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        """Size of one element in bytes."""
        return self.tensor.element_size()

    @property
    def flags(self):
        """A ``Flags`` object describing contiguity, ownership and writeability."""
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        """The tensor's ``data_ptr()``."""
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes taken by the array's elements: ``size * itemsize``."""
        # Computed from the element count rather than from the underlying
        # storage, so that a view reports only its own elements and not the
        # whole buffer it shares with its base.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        """The array with its axes transposed (``self.transpose()``)."""
        return self.transpose()

    @property
    def real(self):
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new ``ndarray`` whose tensor is converted to ``dtype``.

        Only the default values of ``order``, ``casting``, ``subok`` and
        ``copy`` are supported; anything else raises ``NotImplementedError``.
        """
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order} is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting} is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok} is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        t = self.tensor.to(torch_dtype)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        return torch.flatten(self)

    def resize(self, *new_shape, refcheck=False):
        """Resize the wrapped tensor in place; newly added elements are zeroed.

        Accepts both ``x.resize((2, 2))`` and ``x.resize(2, 2)``. Calling it
        with no shape or with ``None`` does nothing.

        Raises:
            NotImplementedError: if ``refcheck`` is truthy.
            ValueError: if any requested dimension is negative.
        """
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck} is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Return a new ``ndarray`` over ``self.tensor.view(<torch dtype>)``.

        ``dtype`` defaults to the array's current dtype. Passing ``type``
        raises ``NotImplementedError``.
        """
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type} is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        return self.tensor.tolist()

    def __iter__(self):
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """Return True if the single stored value equals its ``int()`` conversion.

        Any failure while extracting or converting the value yields False.
        """
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/_core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1:
            # int argument
            return self.ravel()[args[0]]
        else:
            return self.__getitem__(args)

    def __getitem__(self, index):
        """Index the array and return the result as a new ``ndarray``.

        Slices with a negative step are rewritten before the index is handed
        to the tensor: the tensor is flipped along the axis the slice
        addresses and the slice is replaced by the equivalent positive-step
        slice on the flipped tensor.
        """
        tensor = self.tensor

        def is_neg_slice(s):
            return isinstance(s, slice) and s.step is not None and s.step < 0

        def consumed_dims(entry):
            # Number of tensor dimensions addressed by one index entry.
            if entry is None or entry is Ellipsis:
                return 0
            probe = entry.tensor if isinstance(entry, ndarray) else entry
            is_bool, depth = _get_bool_depth(probe)
            # Boolean masks address one dimension per mask dimension.
            return depth if is_bool else 1

        def target_dim(entries, position):
            # Tensor axis addressed by entries[position]. Entries after an
            # ellipsis are aligned with the end of the shape, so they are
            # counted from the back (as a negative axis).
            if builtins.any(e is Ellipsis for e in entries[:position]):
                return -builtins.sum(consumed_dims(e) for e in entries[position:])
            return builtins.sum(consumed_dims(e) for e in entries[:position])

        def neg_step(dim, s):
            nonlocal tensor
            tensor = torch.flip(tensor, (dim,))

            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            # After the flip, the element at position k sits at position
            # -k - 1, so both bounds are mirrored; None bounds stay None.
            start = None if s.start is None else -s.start - 1
            stop = None if s.stop is None else -s.stop - 1

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            # Only rebuild the index when there is something to rewrite.
            if builtins.any(is_neg_slice(s) for s in index):
                entries = tuple(index)
                index = type(index)(
                    neg_step(target_dim(entries, i), s) if is_neg_slice(s) else s
                    for i, s in enumerate(entries)
                )
        elif is_neg_slice(index):
            index = neg_step(0, index)

        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        # Apply NumPy-compatible indexing conversion
        index = _numpy_compatible_indexing(index)
        # Apply NumPy-compatible empty ellipsis behavior
        index, maybe_squeeze, _ = _numpy_empty_ellipsis_patch(index, tensor.ndim)
        return maybe_squeeze(ndarray(tensor.__getitem__(index)))

    def __setitem__(self, index, value):
        """Assign ``value`` to the selected elements of the wrapped tensor.

        Non-scalar values are converted with ``normalize_array_like``; the
        value is then cast to the tensor's dtype if needed. Unlike
        ``__getitem__``, negative-step slices are not rewritten here.
        """
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        # Apply NumPy-compatible indexing conversion
        index = _numpy_compatible_indexing(index)
        # Apply NumPy-compatible empty ellipsis behavior
        index, _, maybe_unsqueeze = _numpy_empty_ellipsis_patch(index, self.tensor.ndim)

        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, maybe_unsqueeze(value))

    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
|
|
|
|
|
|
def _tolist(obj):
    """Return a list copy of ``obj`` with nested containers and arrays expanded.

    Nested lists and tuples are converted recursively (tuples become lists),
    ``ndarray`` elements are replaced by ``elem.tensor.tolist()``, and every
    other element is kept as is.
    """

    def _convert(item):
        if isinstance(item, (list, tuple)):
            item = _tolist(item)
        return item.tensor.tolist() if isinstance(item, ndarray) else item

    return [_convert(item) for item in obj]
|
|
|
|
|
|
# This is ideally the only place which talks to ndarray directly.
# The rest goes through asarray (preferred) or array.


def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Build an ``ndarray`` from ``obj``.

    ``obj`` may be an ``ndarray``, a list/tuple (of tensors, or of arbitrary
    nested values and ndarrays), or anything ``_util._coerce_to_tensor``
    accepts. An existing ``ndarray`` is returned as is when ``copy`` is
    False, no ``dtype`` is requested and it already has at least ``ndmin``
    dimensions.

    Raises:
        NotImplementedError: if ``subok`` is not False, ``like`` is not None,
            or ``order`` is not ``"K"``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    # a happy path: nothing to convert, hand the input back
    reuse_input = (
        isinstance(obj, ndarray)
        and copy is False
        and dtype is None
        and ndmin <= obj.ndim
    )
    if reuse_input:
        return obj

    if isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        all_tensors = bool(obj) and all(isinstance(x, torch.Tensor) for x in obj)
        if all_tensors:
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    # unwrap an ndarray to its tensor
    if isinstance(obj, ndarray):
        obj = obj.tensor

    # translate the requested dtype, if any, to a torch dtype
    torch_dtype = None if dtype is None else _dtypes.dtype(dtype).torch_dtype

    return ndarray(_util._coerce_to_tensor(obj, torch_dtype, copy, ndmin))
|
|
|
|
|
|
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert ``a`` to an ``ndarray`` via ``array`` with ``copy=False`` and ``ndmin=0``."""
    return array(a, dtype=dtype, copy=False, order=order, ndmin=0, like=like)
|
|
|
|
|
|
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return ``a`` as an ``ndarray`` whose tensor is contiguous.

    If the converted array is already contiguous it is returned directly
    (which may be ``a`` itself). Otherwise a new ``ndarray`` wrapping a
    contiguous copy is returned.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if arr.tensor.is_contiguous():
        return arr
    # asarray may hand back the caller's own ndarray, so the contiguous copy
    # goes into a fresh wrapper. Rebinding ``arr.tensor`` would silently
    # detach the caller's array from the memory it is a view of.
    return ndarray(arr.tensor.contiguous())
|
|
|
|
|
|
def from_dlpack(x, /):
    """Wrap the tensor produced by ``torch.from_dlpack(x)`` in an ``ndarray``."""
    return ndarray(torch.from_dlpack(x))
|
|
|
|
|
|
def _extract_dtype(entry):
    """Return the dtype denoted by ``entry``.

    ``entry`` is first interpreted as a dtype specification; if that fails for
    any reason it is converted to an array and the array's dtype is used.
    """
    try:
        return _dtypes.dtype(entry)
    except Exception:
        return asarray(entry).dtype
|
|
|
|
|
|
def can_cast(from_, to, casting="safe"):
    """Report whether ``from_`` can be cast to ``to`` under the ``casting`` rule.

    Both arguments are resolved with ``_extract_dtype``; the decision is made
    by ``_dtypes_impl.can_cast_impl`` on the corresponding torch dtypes.
    """
    source, target = _extract_dtype(from_), _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(source.torch_dtype, target.torch_dtype, casting)
|
|
|
|
|
|
def result_type(*arrays_and_dtypes):
    """Return the dtype that results from combining the given arrays and dtypes.

    Each argument is turned into a tensor: array-likes through ``asarray``,
    and anything ``asarray`` rejects with ``RuntimeError``, ``ValueError`` or
    ``TypeError`` is treated as a dtype and represented by a one-element
    tensor of that dtype. The tensors are passed to
    ``_dtypes_impl.result_type_impl``.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            return torch.empty(1, dtype=_dtypes.dtype(entry).torch_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))
|