mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-27 09:04:53 +08:00 
			
		
		
		
	Remove useless parentheses in `raise` statements if the exception type is raised with no argument. Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261 Approved by: https://github.com/albanD
		
			
				
	
	
		
			592 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			592 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: ignore-errors
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import builtins
 | |
| import math
 | |
| import operator
 | |
| from typing import Sequence
 | |
| 
 | |
| import torch
 | |
| 
 | |
| from . import _dtypes, _dtypes_impl, _funcs, _ufuncs, _util
 | |
| from ._normalizations import (
 | |
|     ArrayLike,
 | |
|     normalize_array_like,
 | |
|     normalizer,
 | |
|     NotImplementedType,
 | |
| )
 | |
| 
 | |
# Alias of ``None``, exposed under NumPy's name ``newaxis``.
newaxis = None

# Every flag name understood by ``Flags``.  Only some of them are actually
# filled in by ``ndarray.flags``; reading the others raises NotImplementedError.
FLAGS = (
    "C_CONTIGUOUS F_CONTIGUOUS OWNDATA WRITEABLE ALIGNED WRITEBACKIFCOPY "
    "FNC FORC BEHAVED CARRAY FARRAY"
).split()

# Short aliases accepted by ``Flags.__getitem__`` (e.g. ``flags["C"]``).
# FNC and FORC have no alias.
SHORTHAND_TO_FLAGS = dict(
    C="C_CONTIGUOUS",
    F="F_CONTIGUOUS",
    O="OWNDATA",
    W="WRITEABLE",
    A="ALIGNED",
    X="WRITEBACKIFCOPY",
    B="BEHAVED",
    CA="CARRAY",
    FA="FARRAY",
)
 | |
| 
 | |
| 
 | |
class Flags:
    """Read-only, NumPy-style container of array layout flags.

    A flag can be read three ways: by its full upper-case name
    (``flags["C_CONTIGUOUS"]``), by a short alias from ``SHORTHAND_TO_FLAGS``
    (``flags["C"]``), or as a lower-case attribute (``flags.c_contiguous``).
    Names in ``FLAGS`` that were not supplied raise NotImplementedError;
    unknown names raise KeyError / AttributeError.  Writing any known flag
    raises NotImplementedError.
    """

    def __init__(self, flag_to_value: dict):
        # Sanity check: only names listed in FLAGS may be supplied.
        assert all(flag in FLAGS for flag in flag_to_value)
        self._flag_to_value = flag_to_value

    def __getattr__(self, attr: str):
        # Reached only when regular lookup fails: ``c_contiguous`` -> ``C_CONTIGUOUS``.
        upper = attr.upper()
        if not (attr.islower() and upper in FLAGS):
            raise AttributeError(f"No flag attribute '{attr}'")
        return self[upper]

    def __getitem__(self, key):
        # Resolve a short alias ("C", "CA", ...) to its full flag name.
        key = SHORTHAND_TO_FLAGS.get(key, key)
        if key not in FLAGS:
            raise KeyError(f"No flag key '{key}'")
        try:
            return self._flag_to_value[key]
        except KeyError as e:
            # A valid flag name that the owner did not populate.
            raise NotImplementedError(f"{key=}") from e

    def __setattr__(self, attr, value):
        if not (attr.islower() and attr.upper() in FLAGS):
            # Ordinary attributes (e.g. ``_flag_to_value``) are stored normally.
            super().__setattr__(attr, value)
            return
        self[attr.upper()] = value

    def __setitem__(self, key, value):
        is_known = key in FLAGS or key in SHORTHAND_TO_FLAGS
        if not is_known:
            raise KeyError(f"No flag key '{key}'")
        raise NotImplementedError("Modifying flags is not implemented")
 | |
| 
 | |
| 
 | |
def create_method(fn, name=None):
    """Wrap ``fn`` in a new function that presents itself as ``ndarray.<name>``.

    ``name`` falls back to ``fn.__name__`` when it is None or otherwise falsy.
    The wrapper forwards all positional and keyword arguments unchanged, so
    when it is stored as a class attribute the instance simply arrives as the
    first positional argument of ``fn``.
    """
    method_name = name if name else fn.__name__

    def method(*args, **kwargs):
        return fn(*args, **kwargs)

    method.__name__ = method_name
    method.__qualname__ = "ndarray." + method_name
    return method
 | |
| 
 | |
| 
 | |
# Methods copied onto ``ndarray`` from ``_funcs``:
#     ndarray.<key>  ->  _funcs.<value>
# A value of None means the function has the same name as the method.
methods = dict.fromkeys(
    [
        "clip",
        "nonzero",
        "repeat",
        "round",
        "squeeze",
        "swapaxes",
        "ravel",
        # linalg
        "diagonal",
        "dot",
        "trace",
        # sorting
        "argsort",
        "searchsorted",
        # reductions
        "argmax",
        "argmin",
        "any",
        "all",
        "max",
        "min",
        "ptp",
        "sum",
        "prod",
        "mean",
        "var",
        "std",
        # scans
        "cumsum",
        "cumprod",
        # advanced indexing
        "take",
        "choose",
    ]
)

# Unary and comparison dunders:  ndarray.__<key>__  ->  _ufuncs.<value>
# (None: the ufunc has the same name as the key).
dunder = dict(
    abs="absolute",
    invert=None,
    pos="positive",
    neg="negative",
    gt="greater",
    lt="less",
    ge="greater_equal",
    le="less_equal",
)

# Binary dunders that also get reflected (__r<key>__) and in-place
# (__i<key>__) variants; values follow the same convention as ``dunder``.
# A literal is required here because "and" / "or" are keywords.
ri_dunder = {
    "add": None,
    "sub": "subtract",
    "mul": "multiply",
    "truediv": "divide",
    "floordiv": "floor_divide",
    "pow": "power",
    "mod": "remainder",
    "and": "bitwise_and",
    "or": "bitwise_or",
    "xor": "bitwise_xor",
    "lshift": "left_shift",
    "rshift": "right_shift",
    "matmul": None,
}
 | |
| 
 | |
| 
 | |
def _upcast_int_indices(index):
    """Cast small-integer index tensors to int64, recursing into tuples.

    Tensors of dtype int8/int16/int32/uint8 are converted with
    ``.to(torch.int64)``; tuples are rebuilt element-wise; anything else
    (including lists and int64/bool tensors) is returned unchanged.
    NOTE(review): presumably because torch treats uint8 index tensors as
    masks and wants int64 for integer indexing, whereas NumPy treats every
    integer dtype as an integer index -- confirm.
    """
    if isinstance(index, tuple):
        return tuple(map(_upcast_int_indices, index))
    if isinstance(index, torch.Tensor) and index.dtype in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.uint8,
    ):
        return index.to(torch.int64)
    return index
 | |
| 
 | |
| 
 | |
class _Unspecified:
    """Sentinel type marking a parameter that was not passed at all.

    Lets a signature tell "argument omitted" apart from an explicit ``None``.
    The single shared instance is ``_Unspecified.unspecified``; compare
    against it with ``is``.
    """


_Unspecified.unspecified = _Unspecified()
 | |
| 
 | |
###############################################################
#                      ndarray class                          #
###############################################################


class ndarray:
    """NumPy-style array backed by a ``torch.Tensor``.

    The wrapped tensor lives in ``self.tensor``.  Most methods are generated
    from the ``methods`` / ``dunder`` / ``ri_dunder`` tables and forward to the
    ``_funcs`` / ``_ufuncs`` implementations; the rest delegate to the tensor.
    """

    def __init__(self, t=None):
        """Wrap tensor ``t`` (an empty ``torch.Tensor`` when ``t`` is None).

        Raises:
            ValueError: if ``t`` is neither None nor a ``torch.Tensor``.
        """
        if t is None:
            self.tensor = torch.Tensor()
        elif isinstance(t, torch.Tensor):
            self.tensor = t
        else:
            # The literals are concatenated, so the separating space matters.
            raise ValueError(
                "ndarray constructor is not recommended; prefer "
                "either array(...) or zeros/empty(...)"
            )

    # Register NumPy functions as methods.  Inside a class body ``vars()`` is
    # the namespace under construction, so writing to it defines attributes.
    for method, name in methods.items():
        fn = getattr(_funcs, name or method)
        vars()[method] = create_method(fn, method)

    # Regular methods but coming from ufuncs
    conj = create_method(_ufuncs.conjugate, "conj")
    conjugate = create_method(_ufuncs.conjugate)

    # Unary / comparison dunders, e.g. __neg__ -> _ufuncs.negative
    for method, name in dunder.items():
        fn = getattr(_ufuncs, name or method)
        method = f"__{method}__"
        vars()[method] = create_method(fn, method)

    # Binary dunders plus reflected (__rX__) and in-place (__iX__) variants.
    # ``fn=fn`` binds the current ufunc; a plain closure would see only the
    # last one assigned by the loop.
    for method, name in ri_dunder.items():
        fn = getattr(_ufuncs, name or method)
        plain = f"__{method}__"
        vars()[plain] = create_method(fn, plain)
        rvar = f"__r{method}__"
        vars()[rvar] = create_method(lambda self, other, fn=fn: fn(other, self), rvar)
        ivar = f"__i{method}__"
        vars()[ivar] = create_method(
            lambda self, other, fn=fn: fn(self, other, out=self), ivar
        )

    # There's no __idivmod__
    __divmod__ = create_method(_ufuncs.divmod, "__divmod__")
    __rdivmod__ = create_method(
        lambda self, other: _ufuncs.divmod(other, self), "__rdivmod__"
    )

    # prevent loop variables leaking into the ndarray class namespace
    del ivar, rvar, name, plain, fn, method

    @property
    def shape(self):
        """Tuple of dimension sizes."""
        return tuple(self.tensor.shape)

    @property
    def size(self):
        """Total number of elements."""
        return self.tensor.numel()

    @property
    def ndim(self):
        """Number of dimensions."""
        return self.tensor.ndim

    @property
    def dtype(self):
        """Element type, as a ``_dtypes.dtype`` wrapping the tensor's dtype."""
        return _dtypes.dtype(self.tensor.dtype)

    @property
    def strides(self):
        """Strides in bytes (torch reports them in elements)."""
        elsize = self.tensor.element_size()
        return tuple(stride * elsize for stride in self.tensor.stride())

    @property
    def itemsize(self):
        """Size of one element in bytes."""
        return self.tensor.element_size()

    @property
    def flags(self):
        """Memory-layout flags.

        Only C_CONTIGUOUS, F_CONTIGUOUS, OWNDATA and WRITEABLE are filled in;
        reading any other flag from the returned ``Flags`` raises
        NotImplementedError.
        """
        # Note contiguous in torch is assumed C-style
        return Flags(
            {
                "C_CONTIGUOUS": self.tensor.is_contiguous(),
                # F-contiguous <=> the transpose is C-contiguous
                "F_CONTIGUOUS": self.T.tensor.is_contiguous(),
                # a tensor with no ``_base`` is not a view of another tensor
                "OWNDATA": self.tensor._base is None,
                "WRITEABLE": True,  # pytorch does not have readonly tensors
            }
        )

    @property
    def data(self):
        """Address of the first element.

        NOTE: this is a plain int from ``data_ptr()``, not NumPy's buffer object.
        """
        return self.tensor.data_ptr()

    @property
    def nbytes(self):
        """Bytes taken by the array's elements (NumPy: ``size * itemsize``)."""
        # Asking the storage instead would report the whole underlying buffer,
        # which is wrong for views that cover only part of it.
        return self.tensor.numel() * self.tensor.element_size()

    @property
    def T(self):
        """Transposed array (``self.transpose()``)."""
        return self.transpose()

    @property
    def real(self):
        """Real part, via ``_funcs.real``."""
        return _funcs.real(self)

    @real.setter
    def real(self, value):
        self.tensor.real = asarray(value).tensor

    @property
    def imag(self):
        """Imaginary part, via ``_funcs.imag``."""
        return _funcs.imag(self)

    @imag.setter
    def imag(self, value):
        self.tensor.imag = asarray(value).tensor

    # ctors
    def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True):
        """Return a new array holding a copy of the data converted to ``dtype``.

        Only the default values of ``order``, ``casting``, ``subok`` and
        ``copy`` are supported; anything else raises NotImplementedError.
        """
        if order != "K":
            raise NotImplementedError(f"astype(..., order={order}) is not implemented.")
        if casting != "unsafe":
            raise NotImplementedError(
                f"astype(..., casting={casting}) is not implemented."
            )
        if not subok:
            raise NotImplementedError(f"astype(..., subok={subok}) is not implemented.")
        if not copy:
            raise NotImplementedError(f"astype(..., copy={copy}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        # ``copy=True``: without it ``Tensor.to`` returns ``self.tensor`` itself
        # when the dtype is unchanged, and the result would alias this array.
        t = self.tensor.to(torch_dtype, copy=True)
        return ndarray(t)

    @normalizer
    def copy(self: ArrayLike, order: NotImplementedType = "C"):
        """Return a copy of the array (``Tensor.clone``)."""
        return self.clone()

    @normalizer
    def flatten(self: ArrayLike, order: NotImplementedType = "C"):
        """Return a 1-D copy of the array (NumPy's ``flatten`` always copies)."""
        # ``torch.flatten`` may return the input itself or a view of it; clone
        # so writes to the result never reach the original array.
        return torch.flatten(self).clone()

    def resize(self, *new_shape, refcheck=False):
        """Resize the array in place; newly added trailing elements are zero.

        Accepts ``x.resize((2, 2))`` as well as ``x.resize(2, 2)``; a call with
        no shape (or a single None) is a no-op.

        Raises:
            NotImplementedError: if ``refcheck`` is truthy.
            ValueError: if any dimension is negative.
        """
        # NB: differs from np.resize: fills with zeros instead of making repeated copies of input.
        if refcheck:
            raise NotImplementedError(
                f"resize(..., refcheck={refcheck}) is not implemented."
            )
        if new_shape in [(), (None,)]:
            return

        # support both x.resize((2, 2)) and x.resize(2, 2)
        if len(new_shape) == 1:
            new_shape = new_shape[0]
        if isinstance(new_shape, int):
            new_shape = (new_shape,)

        if builtins.any(x < 0 for x in new_shape):
            raise ValueError("all elements of `new_shape` must be non-negative")

        new_numel, old_numel = math.prod(new_shape), self.tensor.numel()

        self.tensor.resize_(new_shape)

        if new_numel >= old_numel:
            # zero-fill new elements
            assert self.tensor.is_contiguous()
            b = self.tensor.flatten()  # does not copy
            b[old_numel:].zero_()

    def view(self, dtype=_Unspecified.unspecified, type=_Unspecified.unspecified):
        """Reinterpret the same memory as ``dtype`` (defaults to the current dtype).

        Raises:
            NotImplementedError: if ``type`` is given.
        """
        if dtype is _Unspecified.unspecified:
            dtype = self.dtype
        if type is not _Unspecified.unspecified:
            raise NotImplementedError(f"view(..., type={type}) is not implemented.")
        torch_dtype = _dtypes.dtype(dtype).torch_dtype
        tview = self.tensor.view(torch_dtype)
        return ndarray(tview)

    @normalizer
    def fill(self, value: ArrayLike):
        """Fill the array in place with ``value``."""
        # Both Pytorch and NumPy accept 0D arrays/tensors and scalars, and
        # error out on D > 0 arrays
        self.tensor.fill_(value)

    def tolist(self):
        """Nested Python list of the elements (``Tensor.tolist``)."""
        return self.tensor.tolist()

    def __iter__(self):
        # Iterate over the first axis, wrapping each sub-tensor.
        return (ndarray(x) for x in self.tensor.__iter__())

    def __str__(self):
        return (
            str(self.tensor)
            .replace("tensor", "torch.ndarray")
            .replace("dtype=torch.", "dtype=")
        )

    __repr__ = create_method(__str__)

    def __eq__(self, other):
        try:
            return _ufuncs.equal(self, other)
        except (RuntimeError, TypeError):
            # Failed to convert other to array: definitely not equal.
            falsy = torch.full(self.shape, fill_value=False, dtype=bool)
            return asarray(falsy)

    def __ne__(self, other):
        return ~(self == other)

    def __index__(self):
        try:
            return operator.index(self.tensor.item())
        except Exception as exc:
            raise TypeError(
                "only integer scalar arrays can be converted to a scalar index"
            ) from exc

    def __bool__(self):
        return bool(self.tensor)

    def __int__(self):
        return int(self.tensor)

    def __float__(self):
        return float(self.tensor)

    def __complex__(self):
        return complex(self.tensor)

    def is_integer(self):
        """True if the single element has an integral value; False otherwise
        (including when ``item()`` or the ``int`` conversion fails)."""
        try:
            v = self.tensor.item()
            result = int(v) == v
        except Exception:
            result = False
        return result

    def __len__(self):
        return self.tensor.shape[0]

    def __contains__(self, x):
        return self.tensor.__contains__(x)

    def transpose(self, *axes):
        # np.transpose(arr, axis=None) but arr.transpose(*axes)
        return _funcs.transpose(self, axes)

    def reshape(self, *shape, order="C"):
        # arr.reshape(shape) and arr.reshape(*shape)
        return _funcs.reshape(self, shape, order=order)

    def sort(self, axis=-1, kind=None, order=None):
        # ndarray.sort works in-place
        _funcs.copyto(self, _funcs.sort(self, axis, kind, order))

    def item(self, *args):
        """Return one element.

        ``item()`` returns ``tensor.item()``; ``item(i)`` indexes the raveled
        array; ``item(i, j, ...)`` and ``item((i, j, ...))`` use a multi-index.
        """
        # Mimic NumPy's implementation with three special cases (no arguments,
        # a flat index and a multi-index):
        # https://github.com/numpy/numpy/blob/main/numpy/core/src/multiarray/methods.c#L702
        if args == ():
            return self.tensor.item()
        elif len(args) == 1 and not isinstance(args[0], tuple):
            # int argument: flat index into the raveled array
            return self.ravel()[args[0]]
        else:
            # multi-index, given as separate ints or as a single tuple
            index = args[0] if len(args) == 1 else args
            return self.__getitem__(index)

    def __getitem__(self, index):
        """Index like NumPy and wrap the result in a new ``ndarray``.

        torch rejects negative slice steps, so such slices are emulated by
        flipping the tensor along the affected dimension and indexing with the
        equivalent positive-step slice.  For those indices the result is a copy
        rather than a view.
        """
        tensor = self.tensor

        def consumed_dims(item):
            # How many dimensions of the indexed array ``item`` addresses
            # (Ellipsis is resolved separately below).
            if item is None or isinstance(item, bool):
                # newaxis and Python bools insert an axis instead of consuming one
                return 0
            if isinstance(item, ndarray):
                item = item.tensor
            if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
                # a boolean mask spans as many dimensions as it has
                return item.ndim
            return 1

        def neg_step(dim, s):
            # Rewrite a negative-step slice ``s`` over tensor dimension ``dim``.
            if not (isinstance(s, slice) and s.step is not None and s.step < 0):
                return s

            nonlocal tensor
            tensor = torch.flip(tensor, (dim,))

            # After the flip, the element at position j (negative j included)
            # sits at position -(j + 1), so both bounds are mapped that way.
            assert isinstance(s.start, int) or s.start is None
            assert isinstance(s.stop, int) or s.stop is None
            start = None if s.start is None else -(s.start + 1)
            stop = None if s.stop is None else -(s.stop + 1)

            return slice(start, stop, -s.step)

        if isinstance(index, Sequence):
            items = list(index)
            # An Ellipsis stands for all dimensions the other items leave over.
            ellipsis_dims = builtins.max(
                self.tensor.ndim
                - builtins.sum(
                    consumed_dims(it) for it in items if it is not Ellipsis
                ),
                0,
            )
            # Tensor dimension each item applies to (not its position in the index).
            dims = []
            dim = 0
            for it in items:
                dims.append(dim)
                dim += ellipsis_dims if it is Ellipsis else consumed_dims(it)
            index = type(index)(neg_step(d, s) for d, s in zip(dims, items))
        else:
            index = neg_step(0, index)
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)
        return ndarray(tensor.__getitem__(index))

    def __setitem__(self, index, value):
        index = _util.ndarrays_to_tensors(index)
        index = _upcast_int_indices(index)

        # Non-scalar values are normalized and cast to this array's dtype.
        if not _dtypes_impl.is_scalar(value):
            value = normalize_array_like(value)
            value = _util.cast_if_needed(value, self.tensor.dtype)

        return self.tensor.__setitem__(index, value)

    # NOTE: these rebind ``take`` (already created from the ``methods`` table)
    # and add ``put`` as the plain ``_funcs`` functions.
    take = _funcs.take
    put = _funcs.put

    def __dlpack__(self, *, stream=None):
        return self.tensor.__dlpack__(stream=stream)

    def __dlpack_device__(self):
        return self.tensor.__dlpack_device__()
 | |
| 
 | |
| 
 | |
def _tolist(obj):
    """Recursively turn ``obj`` into a list, unwrapping ``ndarray`` elements.

    Nested lists/tuples become lists, each ``ndarray`` element becomes
    ``elem.tensor.tolist()``, and every other element is kept as is.
    """

    def convert(elem):
        if isinstance(elem, (list, tuple)):
            return _tolist(elem)
        if isinstance(elem, ndarray):
            return elem.tensor.tolist()
        return elem

    return [convert(elem) for elem in obj]
 | |
| 
 | |
| 
 | |
# Ideally, this is the only place that talks to ndarray directly; everything
# else goes through asarray (preferred) or array.


def array(obj, dtype=None, *, copy=True, order="K", subok=False, ndmin=0, like=None):
    """Build an ``ndarray`` from ``obj``, in the style of ``np.array``.

    ``obj`` may be an ``ndarray``, a list/tuple, or anything else handed to
    ``_util._coerce_to_tensor``.  A non-empty list/tuple made up only of
    tensors is combined with ``torch.stack``; other lists/tuples are converted
    with ``_tolist``.  When ``obj`` is already an ``ndarray`` and
    ``copy is False``, ``dtype is None`` and ``ndmin <= obj.ndim``, ``obj``
    itself is returned.

    Raises:
        NotImplementedError: if ``subok`` is not False, ``like`` is not None,
            or ``order`` is not ``"K"``.
    """
    if subok is not False:
        raise NotImplementedError("'subok' parameter is not supported.")
    if like is not None:
        raise NotImplementedError("'like' parameter is not supported.")
    if order != "K":
        raise NotImplementedError

    obj_is_ndarray = isinstance(obj, ndarray)

    # Nothing to convert: hand back the very same object.
    if obj_is_ndarray and copy is False and dtype is None and ndmin <= obj.ndim:
        return obj

    if obj_is_ndarray:
        obj = obj.tensor
    elif isinstance(obj, (list, tuple)):
        # FIXME and they have the same dtype, device, etc
        if obj and all(isinstance(x, torch.Tensor) for x in obj):
            # list of arrays: *under torch.Dynamo* these are FakeTensors
            obj = torch.stack(obj)
        else:
            # XXX: remove tolist
            # lists of ndarrays: [1, [2, 3], ndarray(4)] convert to lists of lists
            obj = _tolist(obj)

    torch_dtype = None if dtype is None else _dtypes.dtype(dtype).torch_dtype
    return ndarray(_util._coerce_to_tensor(obj, torch_dtype, copy, ndmin))
 | |
| 
 | |
| 
 | |
def asarray(a, dtype=None, order="K", *, like=None):
    """Convert ``a`` to an ``ndarray`` without forcing a copy.

    Same as ``array(a, ...)`` with ``copy=False`` and ``ndmin=0``.
    """
    return array(a, copy=False, ndmin=0, dtype=dtype, order=order, like=like)
 | |
| 
 | |
| 
 | |
def ascontiguousarray(a, dtype=None, *, like=None):
    """Return ``a`` as an array whose tensor is C-contiguous.

    If ``asarray(a, dtype=dtype, like=like)`` is already contiguous it is
    returned unchanged; otherwise a new ``ndarray`` wrapping a contiguous copy
    is returned.  ``a`` itself is never modified.
    """
    arr = asarray(a, dtype=dtype, like=like)
    if arr.tensor.is_contiguous():
        return arr
    # ``asarray`` may return ``a`` itself, so build a fresh wrapper rather than
    # rebinding ``arr.tensor``: that would silently swap the caller's array
    # onto a new buffer and cut it off from the memory it was viewing.
    return ndarray(arr.tensor.contiguous())
 | |
| 
 | |
| 
 | |
def from_dlpack(x, /):
    """Wrap the tensor produced by ``torch.from_dlpack(x)`` in an ``ndarray``."""
    return ndarray(torch.from_dlpack(x))
 | |
| 
 | |
| 
 | |
def _extract_dtype(entry):
    """Return the dtype described by ``entry``.

    ``entry`` is first tried as a dtype specifier; if ``_dtypes.dtype``
    raises anything, it is treated as array-like and the dtype of
    ``asarray(entry)`` is returned instead.
    """
    try:
        return _dtypes.dtype(entry)
    except Exception:
        return asarray(entry).dtype
 | |
| 
 | |
| 
 | |
def can_cast(from_, to, casting="safe"):
    """Whether ``from_`` can be cast to ``to`` under the ``casting`` rule.

    Each side may be a dtype specifier or anything ``asarray`` accepts; the
    decision itself is delegated to ``_dtypes_impl.can_cast_impl``.
    """
    src, dst = _extract_dtype(from_), _extract_dtype(to)
    return _dtypes_impl.can_cast_impl(src.torch_dtype, dst.torch_dtype, casting)
 | |
| 
 | |
| 
 | |
def result_type(*arrays_and_dtypes):
    """Dtype resulting from combining all arguments (via ``result_type_impl``).

    Each argument is first tried with ``asarray``; if that raises
    RuntimeError, ValueError or TypeError, the argument is treated as a dtype
    specifier and represented by an uninitialized 1-element tensor of that
    dtype.
    """

    def as_tensor(entry):
        try:
            return asarray(entry).tensor
        except (RuntimeError, ValueError, TypeError):
            stand_in_dtype = _dtypes.dtype(entry).torch_dtype
            return torch.empty(1, dtype=stand_in_dtype)

    tensors = [as_tensor(entry) for entry in arrays_and_dtypes]
    return _dtypes.dtype(_dtypes_impl.result_type_impl(*tensors))
 |