mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Use `sym_eq` to check equality on tuple of ints/symints ### DDE ``` torch._dynamo.exc.UserError: Could not guard on data-dependent expression Eq(u0, u1) (unhinted: Eq(u0, u1)). (Size-like symbols: u1, u0) Caused by: return torch.nn.functional.layer_norm( # test/inductor/test_unbacked_symints.py:527 in fn (_refs/__init__.py:3292 in native_layer_norm) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/160683 Approved by: https://github.com/bobrenjc93
		
			
				
	
	
		
			6724 lines
		
	
	
		
			212 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			6724 lines
		
	
	
		
			212 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-decorators
 | |
| # mypy: allow-untyped-defs
 | |
| import builtins
 | |
| import collections
 | |
| import inspect
 | |
| import itertools
 | |
| import math
 | |
| import operator
 | |
| import warnings
 | |
| from collections.abc import Iterable, Sequence
 | |
| from enum import Enum
 | |
| from functools import partial, reduce, singledispatch, wraps
 | |
| from typing import Any, Callable, cast, Optional, overload, Union
 | |
| 
 | |
| import torch
 | |
| import torch._prims as prims
 | |
| import torch._prims_common as utils
 | |
| import torch.utils._pytree as pytree
 | |
| from torch import sym_float, sym_int
 | |
| from torch._prims_common import (
 | |
|     BoolLike,
 | |
|     contiguous_for_memory_format_or_false,
 | |
|     DeviceLikeType,
 | |
|     Dim,
 | |
|     DimsSequenceType,
 | |
|     DimsType,
 | |
|     dtype_to_type,
 | |
|     ELEMENTWISE_TYPE_PROMOTION_KIND,
 | |
|     FloatLike,
 | |
|     FloatWithoutSymFloat,
 | |
|     IntLike,
 | |
|     is_contiguous_or_false,
 | |
|     is_weakly_lesser_type,
 | |
|     Number,
 | |
|     NumberType,
 | |
|     RealNumberType,
 | |
|     REDUCTION_OUTPUT_TYPE_KIND,
 | |
|     ShapeType,
 | |
|     StrideType,
 | |
|     TensorLike,
 | |
|     TensorLikeType,
 | |
|     TensorOrNumberLikeType,
 | |
|     TensorSequenceType,
 | |
| )
 | |
| from torch._prims_common.wrappers import (
 | |
|     _maybe_convert_to_dtype,
 | |
|     _maybe_resize_out,
 | |
|     _safe_copy_out,
 | |
|     elementwise_type_promotion_wrapper,
 | |
|     elementwise_unary_scalar_wrapper,
 | |
|     out_wrapper,
 | |
| )
 | |
| 
 | |
| 
 | |
| # Experimental module containing prototype Python references for existing
 | |
| #   PyTorch operations.
 | |
| 
 | |
| __all__ = [
 | |
|     #
 | |
|     # Elementwise Unary References
 | |
|     #
 | |
|     "abs",
 | |
|     "acos",
 | |
|     "acosh",
 | |
|     "asinh",
 | |
|     "asin",
 | |
|     "atan",
 | |
|     "atanh",
 | |
|     "bitwise_not",
 | |
|     # "cbrt",  # No corresponding torch operation
 | |
|     "ceil",
 | |
|     "conj_physical",
 | |
|     "cos",
 | |
|     "cosh",
 | |
|     "count_nonzero",
 | |
|     "deg2rad",
 | |
|     "digamma",
 | |
|     "erf",
 | |
|     "erfinv",
 | |
|     "erfc",
 | |
|     "exp",
 | |
|     "expm1",
 | |
|     "exponential",
 | |
|     "exp2",
 | |
|     "fill",
 | |
|     "fill_",
 | |
|     "floor",
 | |
|     "frac",
 | |
|     "geometric",
 | |
|     "index_add",
 | |
|     "index_copy",
 | |
|     "index_copy_",
 | |
|     "index_select",
 | |
|     "index_fill",
 | |
|     "index_fill_",
 | |
|     "isfinite",
 | |
|     "isinf",
 | |
|     "isposinf",
 | |
|     "isneginf",
 | |
|     "isnan",
 | |
|     "isreal",
 | |
|     "i0",
 | |
|     "lerp",
 | |
|     "lgamma",
 | |
|     "log",
 | |
|     "log1p",
 | |
|     "log2",
 | |
|     "log10",
 | |
|     "log_normal",
 | |
|     "log_softmax",
 | |
|     "mvlgamma",
 | |
|     "norm",
 | |
|     "normal",
 | |
|     "nan_to_num",
 | |
|     "neg",
 | |
|     "positive",
 | |
|     "rad2deg",
 | |
|     "reciprocal",
 | |
|     "round",  # TODO: model kwargs
 | |
|     "sigmoid",
 | |
|     "sgn",
 | |
|     "sign",
 | |
|     "signbit",
 | |
|     "sin",
 | |
|     "sinc",
 | |
|     "sinh",
 | |
|     "softmax",
 | |
|     "sqrt",
 | |
|     "square",
 | |
|     "tan",
 | |
|     "tanh",
 | |
|     "trace",
 | |
|     "trunc",
 | |
|     #
 | |
|     # Elementwise Binary References
 | |
|     #
 | |
|     "add",
 | |
|     "atan2",
 | |
|     "bitwise_and",
 | |
|     "bitwise_left_shift",
 | |
|     "bitwise_or",
 | |
|     "bitwise_right_shift",
 | |
|     "bitwise_xor",
 | |
|     "clamp_min",
 | |
|     "clamp_max",
 | |
|     "copysign",
 | |
|     "div",
 | |
|     "eq",
 | |
|     "float_power",
 | |
|     "floor_divide",
 | |
|     "fmax",
 | |
|     "fmin",
 | |
|     "fmod",
 | |
|     "gcd",
 | |
|     "ge",
 | |
|     "gt",
 | |
|     "heaviside",
 | |
|     "hypot",
 | |
|     "igamma",
 | |
|     "igammac",
 | |
|     "imag",
 | |
|     "isclose",
 | |
|     "lcm",
 | |
|     # 'ldexp',
 | |
|     "le",
 | |
|     "logaddexp",
 | |
|     "logaddexp2",
 | |
|     "logical_and",
 | |
|     "logical_not",
 | |
|     "logical_or",
 | |
|     "logical_xor",
 | |
|     "logsumexp",
 | |
|     "lt",
 | |
|     # 'max', # implement with reductions
 | |
|     "maximum",
 | |
|     # 'min', # implement with reductions
 | |
|     "minimum",
 | |
|     "mul",
 | |
|     "ne",
 | |
|     "nextafter",
 | |
|     # 'polar',  # abs, cos, sin
 | |
|     "pow",
 | |
|     "real",
 | |
|     "rpow",
 | |
|     "remainder",
 | |
|     "rsub",
 | |
|     "rtruediv",
 | |
|     "rfloordiv",
 | |
|     "sub",
 | |
|     "true_divide",
 | |
|     "trunc_divide",
 | |
|     "xlogy",
 | |
|     #
 | |
|     # Elementwise Ternary References
 | |
|     #
 | |
|     "addcdiv",
 | |
|     "addcmul",
 | |
|     "clamp",
 | |
|     #
 | |
|     # Conditional references
 | |
|     #
 | |
|     "masked_fill",
 | |
|     "masked_fill_",
 | |
|     "where",
 | |
|     #
 | |
|     # Data conversion and movement references
 | |
|     #
 | |
|     "clone",
 | |
|     "copy_to",  # TODO: add OpInfo (or implement .to)
 | |
|     "item",
 | |
|     "to",
 | |
|     #
 | |
|     # Reduction ops
 | |
|     #
 | |
|     "all",
 | |
|     "amax",
 | |
|     "amin",
 | |
|     "any",
 | |
|     "cumsum",
 | |
|     "cumprod",
 | |
|     "mean",
 | |
|     "dot",
 | |
|     "vdot",
 | |
|     "std",
 | |
|     "std_mean",
 | |
|     "sum",
 | |
|     "sum_to_size",
 | |
|     "prod",
 | |
|     "var",
 | |
|     "var_mean",
 | |
|     #
 | |
|     # Linear algebra ops
 | |
|     #
 | |
|     "addr",
 | |
|     #
 | |
|     # View & Shape Ops
 | |
|     #
 | |
|     "alias",
 | |
|     "alias_copy",
 | |
|     "atleast_1d",
 | |
|     "atleast_2d",
 | |
|     "atleast_3d",
 | |
|     "as_strided",
 | |
|     "as_strided_copy",
 | |
|     "as_strided_scatter",
 | |
|     "block_diag",
 | |
|     "broadcast_shapes",
 | |
|     "broadcast_tensors",
 | |
|     "broadcast_to",
 | |
|     "cat",
 | |
|     "chunk",
 | |
|     "column_stack",
 | |
|     "conj",
 | |
|     "constant_pad_nd",
 | |
|     "contiguous",
 | |
|     "diag_embed",
 | |
|     "diag",
 | |
|     "diagonal",
 | |
|     "diagonal_copy",
 | |
|     "diagonal_scatter",
 | |
|     "dsplit",
 | |
|     "dstack",
 | |
|     "expand",
 | |
|     "expand_as",
 | |
|     "expand_copy",
 | |
|     "flatten",
 | |
|     "flip",
 | |
|     "fliplr",
 | |
|     "flipud",
 | |
|     "hsplit",
 | |
|     "hstack",
 | |
|     "meshgrid",
 | |
|     "movedim",
 | |
|     "narrow",
 | |
|     "narrow_copy",
 | |
|     "native_group_norm",
 | |
|     "native_layer_norm",
 | |
|     "permute",
 | |
|     "permute_copy",
 | |
|     "ravel",
 | |
|     "repeat",
 | |
|     "reshape",
 | |
|     "reshape_as",
 | |
|     "roll",
 | |
|     "rot90",
 | |
|     "rsqrt",
 | |
|     "split_with_sizes",
 | |
|     "stack",
 | |
|     "swap_axes",  # alias for transpose
 | |
|     "squeeze",
 | |
|     "squeeze_copy",
 | |
|     "t",
 | |
|     "t_copy",
 | |
|     "T",
 | |
|     "take_along_dim",
 | |
|     "tensor_split",
 | |
|     "transpose",
 | |
|     "transpose_copy",
 | |
|     "unbind_copy",
 | |
|     "unfold",
 | |
|     "unfold_copy",
 | |
|     "unsqueeze",
 | |
|     "unsqueeze_copy",
 | |
|     "view",
 | |
|     "view_as",
 | |
|     "view_copy",
 | |
|     "vsplit",
 | |
|     "vstack",
 | |
|     "view_as_complex",
 | |
|     "unflatten",
 | |
|     "unbind",
 | |
|     "triu",
 | |
|     "tril",
 | |
|     "triu_indices",
 | |
|     "tril_indices",
 | |
|     #
 | |
|     # Tensor Creation
 | |
|     #
 | |
|     "arange",
 | |
|     "cauchy",
 | |
|     "empty",
 | |
|     "empty_like",
 | |
|     "empty_permuted",
 | |
|     "empty_strided",
 | |
|     "eye",
 | |
|     "full",
 | |
|     "full_like",
 | |
|     "linspace",
 | |
|     "logspace",
 | |
|     "new_empty",
 | |
|     "new_empty_strided",
 | |
|     "new_full",
 | |
|     "new_ones",
 | |
|     "new_zeros",
 | |
|     "ones",
 | |
|     "ones_like",
 | |
|     "randn",
 | |
|     "scalar_tensor",
 | |
|     "zero",
 | |
|     "zeros",
 | |
|     "zeros_like",
 | |
|     #
 | |
|     # Test-related functions
 | |
|     #
 | |
|     "allclose",
 | |
|     "equal",
 | |
|     #
 | |
|     # Statistical operations
 | |
|     #
 | |
|     "bucketize",
 | |
|     #
 | |
|     # Misc
 | |
|     #
 | |
|     "is_complex",
 | |
|     "renorm",
 | |
|     "stft",
 | |
|     "istft",
 | |
| ]
 | |
| 
 | |
# Short module-level aliases for torch objects referenced by the code in this file.
Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
# `aten` is the op namespace handed to `register_decomposition(...)` below.
aten = torch._ops.ops.aten
 | |
| 
 | |
| # Note that the docstrings for the public methods from this file are in
 | |
| # torch/_torch_docs.py
 | |
| 
 | |
| 
 | |
def is_noncontiguous_supported(device):
    # Only a known device whose `type` is "hpu" is reported as unsupported;
    # an unknown (None) device counts as supported.
    if device is None:
        return True
    return device.type != "hpu"
 | |
| 
 | |
| 
 | |
def handle_noncontiguous_outputs(input_tlist, output):
    # Returns `output`, converted with `.contiguous()` when the device taken
    # from the inputs does not support non-contiguous tensors.
    from torch._subclasses.fake_tensor import FakeTensor

    # The first FakeTensor among the inputs decides the device; when there is
    # none the device stays None.
    device = next(
        (t.fake_device for t in input_tlist if isinstance(t, FakeTensor)),
        None,
    )

    if is_noncontiguous_supported(device):
        return output
    return output.contiguous()
 | |
| 
 | |
| 
 | |
| def _broadcast_shapes(*_shapes):
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     shapes = tuple(
 | |
|         (x,) if isinstance(x, IntLike) else x
 | |
|         for x in filter(lambda x: x is not None, _shapes)
 | |
|     )
 | |
| 
 | |
|     # Short-circuits on no input
 | |
|     if len(shapes) == 0:
 | |
|         return None
 | |
| 
 | |
|     # Type checking
 | |
|     # TODO: make common validations available as utils
 | |
|     for shape in shapes:
 | |
|         assert isinstance(shape, Sequence)
 | |
| 
 | |
|     # Computes common shape
 | |
|     common_shape: list[Union[int, torch.SymInt]] = [
 | |
|         1,
 | |
|     ] * reduce(max, (len(shape) for shape in shapes))
 | |
|     for arg_idx, shape in enumerate(shapes):
 | |
|         for idx in range(-1, -1 - len(shape), -1):
 | |
|             # if both 1, or statically known the same, we rather pick non-broadcast path.
 | |
|             if guard_or_false(common_shape[idx] == shape[idx]):
 | |
|                 continue
 | |
|             elif guard_or_false(common_shape[idx] == 1):
 | |
|                 if shape[idx] < 0:
 | |
|                     raise ValueError(
 | |
|                         "Attempting to broadcast a dimension with negative length!"
 | |
|                     )
 | |
|                 common_shape[idx] = shape[idx]
 | |
|             elif guard_or_false(shape[idx] == 1):
 | |
|                 # broadcast case .
 | |
|                 continue
 | |
|             else:
 | |
|                 # If broadcasting is undecided we pick non-broadcast path and add runtime assertion.
 | |
|                 torch._check(
 | |
|                     common_shape[idx] == shape[idx],
 | |
|                     lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
 | |
|                     f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
 | |
|                     f"should be broadcastable to {common_shape}",
 | |
|                 )
 | |
| 
 | |
|     return common_shape
 | |
| 
 | |
| 
 | |
def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Returns a tuple with one entry per argument: None and numbers pass
    # through untouched, tensors are expanded to the common broadcast shape of
    # all tensor arguments when their shape differs from it.
    common_shape = _broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def _expand_one(x):
        if x is None or isinstance(x, Number):
            return x
        if not isinstance(x, TensorLike):
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )
        # CPU scalar tensors are optionally left as they are.
        if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
            return x
        if utils.same_shape(x.shape, common_shape):
            return x
        return x.expand(common_shape)

    return tuple(_expand_one(x) for x in args)
 | |
| 
 | |
| 
 | |
| # Utilities should come BEFORE this import
 | |
| from torch._decomp import register_decomposition
 | |
| 
 | |
| 
 | |
| #
 | |
| # Elementwise unary references
 | |
| #
 | |
| 
 | |
# Sentinel used as the default `aten_op` of `_make_elementwise_unary_reference`:
# when it is still in place, the ATen op is looked up from the prim's name.
infer_aten_op = object()
 | |
| 
 | |
| 
 | |
# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    extra_meta=None,
    exact_dtype=False,
) -> Callable:
    """
    Returns a decorator that turns a one-argument function `prim` into an
    elementwise unary reference.

    The produced reference runs `extra_meta(a)` first (when given), calls
    `prim(a)` and passes the result through `handle_noncontiguous_outputs`.
    It is wrapped with type promotion on `a` (using `type_promotion_kind`),
    `elementwise_unary_scalar_wrapper`, `out_wrapper(exact_dtype=exact_dtype)`
    and finally `wraps(prim)`.

    `aten_op` selects the op the reference is registered under as a
    decomposition: the default sentinel looks the op up from `prim.__name__`,
    and `None` skips registration.
    """

    def inner(prim: Callable):
        nonlocal aten_op

        def _ref(a: TensorLikeType) -> TensorLikeType:
            if extra_meta is not None:
                extra_meta(a)
            return handle_noncontiguous_outputs([a], prim(a))

        # Wrappers are applied innermost-first; `wraps(prim)` goes last so the
        # final callable carries the metadata of `prim`.
        promoted = elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )(_ref)
        scalar_aware = elementwise_unary_scalar_wrapper(promoted)
        with_out = out_wrapper(exact_dtype=exact_dtype)(scalar_aware)
        ref = wraps(prim)(with_out)

        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op)(ref)

        return ref

    return inner
 | |
| 
 | |
| 
 | |
def _make_alias(fn, name):
    """
    Returns a new function that forwards every call to `fn` but is presented
    under a different identity: its `__name__` is `name` and its `__module__`
    is the module of whoever called `_make_alias`.

    A plain `alias = fn` assignment would instead keep `fn.__name__` and
    `fn.__module__`.
    """
    # The caller's module name is read from the globals of the calling frame,
    # so this lookup has to happen directly inside `_make_alias`.
    caller_globals = inspect.currentframe().f_back.f_globals  # type: ignore[union-attr]

    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)

    _fn.__name__ = name
    _fn.__module__ = caller_globals["__name__"]
    return _fn
 | |
| 
 | |
| 
 | |
def _make_inplace(fn):
    """
    Given a function with an out variant (i.e. one using `out_wrapper()`),
    returns its in-place variant: the first argument is also passed as `out`.

    The variant is named `<fn name>_`, registered as the decomposition of the
    `aten` op of that name, and its name is appended to the `__all__` of the
    module where `fn` is defined when it is not listed there yet.
    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _inplace(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)

    inplace_name = f"{fn.__name__}_"
    _inplace.__name__ = inplace_name
    registered = register_decomposition(getattr(aten, inplace_name))(_inplace)

    # We access the __all__ attribute of the module where fn is defined
    # There may be a cleaner way of doing this...
    exported = inspect.getmodule(fn).__all__  # type: ignore[union-attr]
    if inplace_name not in exported:
        exported.append(inplace_name)
    return registered
 | |
| 
 | |
| 
 | |
# Reference for `abs`: hands `a` to prims.abs. The decorator supplies
# COMPLEX_TO_FLOAT type promotion and `out=` handling (with exact_dtype=True).
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    exact_dtype=True,
)
def abs(a):
    result = prims.abs(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `acos`: hands `a` to prims.acos. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def acos(a):
    result = prims.acos(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `acosh`: hands `a` to prims.acosh. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def acosh(a):
    result = prims.acosh(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `asin`: hands `a` to prims.asin. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def asin(a):
    result = prims.asin(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `asinh`: hands `a` to prims.asinh. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def asinh(a):
    result = prims.asinh(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `atan`: hands `a` to prims.atan. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def atan(a):
    result = prims.atan(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `atanh`: hands `a` to prims.atanh. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def atanh(a):
    result = prims.atanh(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `bitwise_not`: hands `a` to prims.bitwise_not. The decorator
# supplies DEFAULT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_not(a):
    result = prims.bitwise_not(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `ceil`: hands `a` to prims.ceil. The decorator supplies
# DEFAULT type promotion and `out=` handling (with exact_dtype=True).
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def ceil(a):
    result = prims.ceil(a)
    return result
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.is_complex)
def is_complex(input: TensorLikeType):
    # Only the dtype of `input` is inspected; the answer comes from
    # utils.is_complex_dtype.
    input_dtype = input.dtype
    return utils.is_complex_dtype(input_dtype)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    # Complex inputs go through prims.conj_physical; any other dtype is
    # returned as-is.
    if utils.is_complex_dtype(input.dtype):
        return prims.conj_physical(input)
    return input
 | |
| 
 | |
| 
 | |
# Reference for `cos`: hands `a` to prims.cos. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def cos(a):
    result = prims.cos(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `cosh`: hands `a` to prims.cosh. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def cosh(a):
    result = prims.cosh(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `digamma`: hands `a` to prims.digamma. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def digamma(a):
    result = prims.digamma(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `erf`: hands `a` to prims.erf. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erf(a):
    result = prims.erf(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `erfinv`: hands `a` to the prim, which is spelled
# `prims.erf_inv`. The decorator supplies INT_TO_FLOAT type promotion and
# `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erfinv(a):
    result = prims.erf_inv(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `erfc`: hands `a` to prims.erfc. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erfc(a):
    result = prims.erfc(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `exp`: hands `a` to prims.exp. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def exp(a):
    result = prims.exp(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `expm1`: hands `a` to prims.expm1. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def expm1(a):
    result = prims.expm1(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `exp2`: hands `a` to prims.exp2. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def exp2(a):
    result = prims.exp2(a)
    return result
 | |
| 
 | |
| 
 | |
# Fill has its own implementation because it has a value parameter
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    # A one-element tuple of argument names. Mind the comma placement:
    # ("a,") would be the plain string "a,", not a tuple.
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    # Returns the result of prims.fill(a, value), i.e. the out-of-place fill
    # of `a` with the scalar `value`.
    #
    # Raises ValueError when the Python type of `value` is not weakly lesser
    # than the Python type corresponding to `a.dtype` (for example a float
    # value for an integer tensor).
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)

    return prims.fill(a, value)
 | |
| 
 | |
| 
 | |
def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    # In-place fill: the filled tensor produced by prims.fill is copied into
    # `a` with prims.copy_to, and `a` itself is returned.
    prims.copy_to(a, prims.fill(a, value))
    return a
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.zero)
@out_wrapper()
def zero(input: TensorLikeType) -> TensorLikeType:
    # Out-of-place: the result is whatever torch.zeros_like(input) builds;
    # `input` itself is not written to here.
    zeros = torch.zeros_like(input)
    return zeros
 | |
| 
 | |
| 
 | |
# Reference for `floor`: hands `a` to prims.floor. The decorator supplies
# DEFAULT type promotion and `out=` handling (with exact_dtype=True).
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def floor(a):
    result = prims.floor(a)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def frac(x: TensorLikeType) -> TensorLikeType:
    # x minus its sign-preserving integer part, where the integer part is
    # built as floor(|x|) * sign(x).
    floored_magnitude = torch.floor(torch.abs(x))
    trunc_x = torch.mul(floored_magnitude, torch.sign(x))
    return torch.sub(x, trunc_x)
 | |
| 
 | |
| 
 | |
| # imag does not use _make_elementwise_unary_reference because it does not support out
 | |
| def imag(a: TensorLikeType) -> TensorLikeType:
 | |
|     assert isinstance(a, TensorLike)
 | |
|     torch._check(
 | |
|         utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
 | |
|     )
 | |
|     return prims.imag(a)
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    dtype = a.dtype
    if not (utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype)):
        # Dtypes that are neither float nor complex get an all-True bool
        # tensor shaped like `a`.
        return ones_like(a, dtype=torch.bool)

    return prims.isfinite(a)
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def isinf(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        # A complex value counts as infinite when either component does.
        real_is_inf = isinf(torch.real(a))
        imag_is_inf = isinf(torch.imag(a))
        return torch.logical_or(real_is_inf, imag_is_inf)
    if not utils.is_float_dtype(a.dtype):
        # Remaining (non-float, non-complex) dtypes: all-False bool tensor.
        return torch.zeros_like(a, dtype=torch.bool)
    return torch.abs(a) == float("inf")
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    # Complex inputs are rejected up front.
    input_is_complex = utils.is_complex_dtype(a.dtype)
    torch._check(
        not input_is_complex,
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if not utils.is_float_dtype(a.dtype):
        # Non-float dtypes: all-False bool tensor shaped like `a`.
        return torch.zeros_like(a, dtype=torch.bool)
    return a == float("inf")
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    # Complex inputs are rejected up front.
    input_is_complex = utils.is_complex_dtype(a.dtype)
    torch._check(
        not input_is_complex,
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if not utils.is_float_dtype(a.dtype):
        # Non-float dtypes: all-False bool tensor shaped like `a`.
        return torch.zeros_like(a, dtype=torch.bool)
    return a == float("-inf")
 | |
| 
 | |
| 
 | |
# Reference for `isnan`, implemented as the elementwise comparison `a != a`
# through prims.ne. The decorator supplies ALWAYS_BOOL type promotion and
# `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def isnan(a: TensorLikeType) -> TensorLikeType:
    differs_from_itself = prims.ne(a, a)
    return differs_from_itself
 | |
| 
 | |
| 
 | |
# alias: `mvlgamma` forwards to torch.special.multigammaln and is exposed under
# the name "mvlgamma" (see `_make_alias`).
mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    if not utils.is_complex_dtype(a.dtype):
        # Non-complex dtypes: all-True bool tensor shaped like `a`.
        return torch.ones_like(a, dtype=torch.bool)
    # Complex dtypes: real exactly where the imaginary part compares equal to 0.
    return torch.imag(a) == 0
 | |
| 
 | |
| 
 | |
# TODO: if this is special maybe it should be defined there and imported here?
# Reference for `i0`: hands `a` to prims.bessel_i0. The ATen op is given
# explicitly (aten.i0) instead of being looked up from the function name.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    aten_op=aten.i0,
)
def i0(a):
    result = prims.bessel_i0(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `lgamma`: hands `a` to prims.lgamma. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def lgamma(a):
    result = prims.lgamma(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `log`: hands `a` to prims.log. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log(a):
    result = prims.log(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `log1p`: hands `a` to prims.log1p. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log1p(a):
    result = prims.log1p(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `log2`: hands `a` to prims.log2. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log2(a):
    result = prims.log2(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `log10`: hands `a` to prims.log10. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log10(a):
    result = prims.log10(a)
    return result
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # The result dtype is `dtype` when given, otherwise the dtype of `a`; the
    # arithmetic itself runs in the matching computation dtype.
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    # a - logsumexp(a) along `dim`, keeping that dimension so the shapes line up.
    shifted = a_ - logsumexp(a_, dim, keepdim=True)
    return _maybe_convert_to_dtype(shifted, result_dtype)  # type: ignore[return-value]
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.logsumexp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logsumexp(
    self: TensorLikeType, dim: DimsType, keepdim: bool = False
) -> TensorLikeType:
    # A single dimension is wrapped into a one-element tuple.
    dims = dim if isinstance(dim, Iterable) else (dim,)

    # Empty input: no maximum is taken, the plain formula is used directly.
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dims, keepdim).log()

    # Subtract the per-slice maximum of the real part before exponentiating and
    # add it back after the log; infinite maxima are replaced by 0 first.
    peak = torch.amax(torch.real(self), dims, keepdim=True)
    peak = torch.masked_fill(peak, peak.abs() == float("inf"), 0)
    peak_for_output = peak if keepdim else torch.squeeze(peak, dims)
    shifted_sum = torch.sum(torch.exp(self - peak), dims, keepdim)
    return shifted_sum.log().add(peak_for_output)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    # Boolean and integer inputs are simply cloned.
    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return a.clone()

    # Missing replacements fall back to 0.0 for NaN and to the largest /
    # smallest finite value of the dtype for +inf / -inf.
    nan_value = 0.0 if nan is None else nan
    posinf_value = torch.finfo(a.dtype).max if posinf is None else posinf
    neginf_value = torch.finfo(a.dtype).min if neginf is None else neginf

    # Each mask is computed on the original `a`; replacements are layered in
    # the order NaN, -inf, +inf.
    result = torch.where(torch.isnan(a), nan_value, a)  # type: ignore[call-overload]
    result = torch.where(torch.isneginf(a), neginf_value, result)  # type: ignore[call-overload]
    result = torch.where(torch.isposinf(a), posinf_value, result)  # type: ignore[call-overload]
    return result
 | |
| 
 | |
| 
 | |
| def _neg_meta(a: TensorLikeType):
 | |
|     torch._check(
 | |
|         a.dtype is not torch.bool,
 | |
|         lambda: (
 | |
|             "Negation, the `-` operator, on a bool tensor is not supported. "
 | |
|             "If you are trying to invert a mask, use the `~` or `logical_not()` "
 | |
|             "operator instead."
 | |
|         ),
 | |
|     )
 | |
| 
 | |
| 
 | |
# Reference for `neg`: hands `a` to prims.neg. `_neg_meta` is run on the
# input first (as the decorator's extra_meta); the decorator also supplies
# DEFAULT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    extra_meta=_neg_meta,
)
def neg(a):
    result = prims.neg(a)
    return result
 | |
| 
 | |
| 
 | |
# positive does not use _make_elementwise_unary_reference because it does not support out
# CompositeImplicitAutograd - don't register decomp
def positive(a: TensorLikeType) -> TensorLikeType:
    # Returns `a` itself (no copy); bool tensors are rejected.
    assert isinstance(a, TensorLike)
    if a.dtype is not torch.bool:
        return a
    raise RuntimeError("positive does not support bool tensors.")
 | |
| 
 | |
| 
 | |
# real does not use _make_elementwise_unary_reference because it does not support out
def real(a: TensorLikeType) -> TensorLikeType:
    # Complex inputs go through prims.real; anything else is returned as-is.
    assert isinstance(a, TensorLike)
    return prims.real(a) if utils.is_complex_dtype(a.dtype) else a
 | |
| 
 | |
| 
 | |
# Reference for `reciprocal`: hands `a` to prims.reciprocal. The decorator
# supplies INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def reciprocal(a):
    result = prims.reciprocal(a)
    return result
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.round)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
    # With decimals == 0 this is a plain prims.round.
    if decimals == 0:
        return prims.round(a)

    # Otherwise: scale by 10**decimals, round, then scale back by 10**(-decimals).
    scale = 10**decimals
    inverse_scale = 10 ** (-decimals)
    rounded_scaled = prims.round(prims.mul(a, scale))
    return prims.mul(rounded_scaled, inverse_scale)
 | |
| 
 | |
| 
 | |
# Reference for `rsqrt`: hands `a` to prims.rsqrt. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def rsqrt(a):
    result = prims.rsqrt(a)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    # 1 / (1 + exp(-a)), composed from the sibling references.
    denominator = add(1, exp(neg(a)))
    return true_divide(1, denominator)
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def sgn(a):
    # Non-complex inputs defer to `sign`.
    if not utils.is_complex_dtype(a.dtype):
        return a.sign()

    # Complex inputs: a / |a|, with 0 wherever |a| compares equal to 0.
    a_abs = a.abs()
    return torch.where(a_abs == 0, 0, a / a_abs)
 | |
| 
 | |
| 
 | |
# Reference for `sign`: hands `a` to prims.sign. The decorator supplies
# DEFAULT type promotion and `out=` handling (with exact_dtype=True).
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def sign(a):
    result = prims.sign(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `signbit`: hands `a` to prims.signbit. The decorator supplies
# ALWAYS_BOOL type promotion and `out=` handling (with exact_dtype=True).
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def signbit(a):
    result = prims.signbit(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `sin`: hands `a` to prims.sin. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def sin(a):
    result = prims.sin(a)
    return result
 | |
| 
 | |
| 
 | |
# Autograd note: This will give the right first derivative at zero (by chance),
# but not the right second derivative
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def sinc(a):
    # sin(pi * a) / (pi * a), with the value 1 wherever pi * a compares equal to 0.
    scaled = math.pi * a
    at_zero = scaled == 0
    quotient = torch.sin(scaled) / scaled
    return torch.where(at_zero, 1, quotient)
 | |
| 
 | |
| 
 | |
# Reference for `sinh`: hands `a` to prims.sinh. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def sinh(a):
    result = prims.sinh(a)
    return result
 | |
| 
 | |
| 
 | |
# Reference for `sqrt`: hands `a` to prims.sqrt. The decorator supplies
# INT_TO_FLOAT type promotion and `out=` handling.
@_make_elementwise_unary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def sqrt(a):
    result = prims.sqrt(a)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd,
)
def square(a: TensorLikeType) -> TensorLikeType:
    """Reference for ``square``: multiplies ``a`` by itself with ``mul``.

    Decorated with the BOOL_TO_LONG promotion kind and ``aten_op=None``.
    """
    squared = mul(a, a)
    return squared
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    """Reference for ``tan``: forwards ``a`` to ``prims.tan``.

    Decorated with the INT_TO_FLOAT type promotion kind.
    """
    result = prims.tan(a)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    """Reference for ``tanh``: forwards ``a`` to ``prims.tanh``.

    Decorated with the INT_TO_FLOAT type promotion kind.
    """
    result = prims.tanh(a)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def trunc(a):
    """Reference for ``trunc``: forwards ``a`` to ``prims.trunc``.

    Decorated with the DEFAULT promotion kind and ``exact_dtype=True``.
    """
    result = prims.trunc(a)
    return result
 | |
| 
 | |
| 
 | |
| # TODO: register this as a real ref/decomposition once TorchInductor supports complex!
 | |
| def view_as_complex(self: TensorLikeType) -> TensorLikeType:
 | |
|     input_dtype = self.dtype
 | |
|     torch._check(
 | |
|         utils.is_float_dtype(input_dtype),
 | |
|         lambda: f"view_as_complex is only supported for floating point"
 | |
|         f"tensors, but got a tensor of scalar type: {input_dtype}",
 | |
|     )
 | |
|     sizes = self.size()
 | |
|     torch._check(
 | |
|         len(sizes) != 0,
 | |
|         lambda: "Input tensor must have one or more dimensions",
 | |
|     )
 | |
|     torch._check(
 | |
|         sizes[-1] == 2,
 | |
|         lambda: "Tensor must have a last dimension of size 2",
 | |
|     )
 | |
| 
 | |
|     old_strides = self.stride()
 | |
|     torch._check(
 | |
|         old_strides[-1] == 1,
 | |
|         lambda: "Tensor must have a last dimension with stride 1",
 | |
|     )
 | |
|     dims = old_strides[:-1]
 | |
|     torch._check(
 | |
|         builtins.all(stride % 2 == 0 for stride in dims),
 | |
|         lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
 | |
|     )
 | |
|     torch._check(
 | |
|         self.storage_offset() % 2 == 0,
 | |
|         lambda: "Tensor must have a storage_offset divisible by 2",
 | |
|     )
 | |
|     return prims.view_element_type(
 | |
|         self, utils.corresponding_complex_dtype(input_dtype)
 | |
|     ).squeeze(-1)
 | |
| 
 | |
| 
 | |
def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
    should_register_decomposition=True,
) -> Callable:
    """Build a decorator that turns a two-argument function into a binary reference.

    The produced reference validates Python-scalar usage, broadcasts its
    operands with ``_maybe_broadcast``, calls the decorated function and passes
    the result through ``handle_noncontiguous_outputs``.

    Args:
        type_promotion_kind: forwarded to ``elementwise_type_promotion_wrapper``
            for the arguments ``a`` and ``b``.
        aten_op: operator to register the reference for.  The default sentinel
            ``infer_aten_op`` means it is looked up with ``utils.get_aten_op``;
            ``None`` disables registration.
        name: name assigned to the reference; defaults to the decorated
            function's ``__name__``.
        has_out: when true, the reference is wrapped with ``out_wrapper()``.
        supports_lhs_python_scalar: whether ``a`` may be a ``Number``.
        supports_rhs_python_scalar: whether ``b`` may be a ``Number``.
        supports_two_python_scalars: whether both may be ``Number`` at once.
        should_register_decomposition: when false, ``register_decomposition``
            is never called.

    Returns:
        A decorator that takes the function to wrap and returns the reference.
    """

    def inner(prim: Callable):
        # ``name`` and ``aten_op`` are rebound in the enclosing scope, so the
        # message lambdas below see the resolved name.
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        promote_operands = elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )

        @wraps(prim)
        @promote_operands
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            lhs_is_scalar = isinstance(a, Number)
            rhs_is_scalar = isinstance(b, Number)

            torch._check_value(
                supports_lhs_python_scalar or not lhs_is_scalar,
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
            )
            torch._check_value(
                supports_rhs_python_scalar or not rhs_is_scalar,
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
            )
            torch._check_value(
                supports_two_python_scalars
                or not (lhs_is_scalar and rhs_is_scalar),
                lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
            )

            a, b = _maybe_broadcast(a, b)
            output = prim(a, b)
            return handle_noncontiguous_outputs([a, b], output)

        ref = out_wrapper()(_ref) if has_out else _ref  # type: ignore[assignment]

        ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if should_register_decomposition and aten_op is not None:
            register_decomposition(aten_op)(ref)

        return ref

    return inner
 | |
| 
 | |
| 
 | |
# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add

    Broadcasts ``a`` and ``b``, scales ``b`` by ``alpha`` when ``alpha`` is
    given, and adds the operands with ``prims.add``.

    Raises:
        ValueError: if the Python type of the operand dtype is not ``bool``
            and ``type(alpha)`` is not weakly lesser than it.
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        # The dtype is taken from ``a`` when it is tensor-like, else from ``b``.
        if isinstance(a, TensorLike):
            dtype = a.dtype
        else:
            dtype = b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        alpha_is_castable = python_type == bool or utils.is_weakly_lesser_type(
            type(alpha), python_type
        )
        if not alpha_is_castable:
            raise ValueError(
                f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            )
        # prims.mul is used only when b is tensor-like; otherwise plain `*`.
        b = prims.mul(b, alpha) if isinstance(b, TensorLike) else b * alpha

    output = prims.add(a, b)
    return handle_noncontiguous_outputs([a, b], output)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def atan2(a, b):
    """Reference for ``atan2``: forwards ``(a, b)`` to ``prims.atan2``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.atan2(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``bitwise_and``: forwards ``(a, b)`` to ``prims.bitwise_and``."""
    result = prims.bitwise_and(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``bitwise_left_shift``: forwards ``(a, b)`` to ``prims.shift_left``."""
    shifted = prims.shift_left(a, b)
    return shifted
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``bitwise_or``: forwards ``(a, b)`` to ``prims.bitwise_or``."""
    result = prims.bitwise_or(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``bitwise_right_shift``.

    Forwards ``(a, b)`` to ``prims.shift_right_arithmetic``.
    """
    shifted = prims.shift_right_arithmetic(a, b)
    return shifted
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``bitwise_xor``: forwards ``(a, b)`` to ``prims.bitwise_xor``."""
    result = prims.bitwise_xor(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
)
def copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    """Reference for ``copysign``: ``abs(a)`` carrying the sign bit of ``b``.

    A ``Number`` ``b`` paired with a tensor ``a`` is first turned into a scalar
    tensor with ``a``'s dtype and device.

    Args:
        a: operand whose absolute value is used.
        b: operand whose ``signbit`` selects between ``-abs(a)`` and ``abs(a)``.

    Raises:
        RuntimeError: if both operands are tensors on different devices.
    """
    if isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        # The operands of copysign are a magnitude and a sign source; there is
        # no divisor or dividend here, so the message names them accordingly.
        msg = (
            f"Expected sign tensor (b) to be on the same device ({a.device}) "
            f"as magnitude tensor (a), but it is found on {b.device}!"
        )
        raise RuntimeError(msg)
    return where(signbit(b), neg(abs(a)), abs(a))
 | |
| 
 | |
| 
 | |
| # complex =  _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div

    Dispatches on ``rounding_mode``: ``None`` uses ``true_divide``,
    ``"trunc"`` uses ``trunc_divide`` and ``"floor"`` uses ``floor_divide``.

    Raises:
        ValueError: for any other ``rounding_mode``.
    """
    if rounding_mode is None:
        return true_divide(a, b)
    if rounding_mode == "trunc":
        return trunc_divide(a, b)
    if rounding_mode == "floor":
        return floor_divide(a, b)

    raise ValueError(
        f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
    )
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``eq``: forwards ``(a, b)`` to ``prims.eq``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.eq(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
def pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    """Reference for ``pow`` with shortcuts for a few scalar operands.

    When ``b`` is a ``Number``: exponent 1 clones ``a``, exponent 2 returns
    ``a * a`` and exponent 0.5 returns ``torch.sqrt(a)``.  When ``a`` is a
    ``Number``: base 1 returns ``b`` filled with ``True`` and base 2 with a
    float or complex ``b`` returns ``torch.exp2(b)``.  Everything else goes to
    ``prims.pow``.
    """
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        # Scalar exponent shortcuts, checked in this order.
        if b == 1.0:
            return a.clone()  # type: ignore[return-value,union-attr]
        if b == 2.0:
            return a * a  # type: ignore[return-value]
        if b == 0.5:
            return torch.sqrt(a)  # type: ignore[arg-type]
    elif isinstance(a, Number):
        # Scalar base shortcuts.
        if a == 1.0:
            return torch.fill(b, True)
        if a == 2.0:
            if utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype):
                return torch.exp2(b)

    return prims.pow(a, b)
 | |
| 
 | |
| 
 | |
# Float power has its own implementation because it has unique type promotion.
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    """Reference for ``float_power``: ``pow`` computed in float64 or complex128.

    Both operands are converted to ``torch.complex128`` when their higher
    dtype is complex and to ``torch.float64`` otherwise, then broadcast and
    passed to ``pow``.

    Raises:
        ValueError: if both ``a`` and ``b`` are ``Number`` instances.
    """
    both_scalars = isinstance(a, Number) and isinstance(b, Number)
    if both_scalars:
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Handles type promotion
    higher_dtype = utils.get_higher_dtype(a, b)
    assert higher_dtype is not None
    compute_dtype = (
        torch.complex128 if utils.is_complex_dtype(higher_dtype) else torch.float64
    )

    # Float power has the following contiguous cast behavior to be
    # consistent with its C++ impl
    a = _maybe_convert_to_dtype(a, compute_dtype)
    b = _maybe_convert_to_dtype(b, compute_dtype)

    a, b = _maybe_broadcast(a, b)
    return pow(a, b)
 | |
| 
 | |
| 
 | |
| # >>> a = torch.tensor(-0.2500, dtype=torch.float64)
 | |
| # tensor(-0.250000000000000, dtype=torch.float64)
 | |
| #
 | |
| # >>> b = torch.tensor(-0.0010, dtype=torch.float64)
 | |
| # tensor(-0.001000000000000, dtype=torch.float64)
 | |
| #
 | |
| # Note: In this case, casting float to double will expand the float mantissa with zeros,
 | |
| # while creating a double generates a distinct mantissa.
 | |
| # >>> torch.tensor(-0.001).to(dtype=torch.float64)
 | |
| # tensor(-0.001000000047497, dtype=torch.float64)
 | |
| #
 | |
| # Floor Division
 | |
| # The difference is caused because torch.remainder(a, b) = -0.001.
 | |
| #
 | |
| # >>> torch.floor(torch.true_divide(a, b))
 | |
| # tensor(250., dtype=torch.float64)
 | |
| #
 | |
| # >>> torch.div(a, b, rounding_mode='floor')
 | |
| # tensor(249., dtype=torch.float64)
 | |
| #
 | |
| # Definition: a // b = (a - remainder(a, b)) / b
 | |
| # >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
 | |
| # tensor(249., dtype=torch.float64)
 | |
| #
 | |
| # For reference, see CPython's implementation:
 | |
| # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
    should_register_decomposition=False,
)
def floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    """Reference for ``floor_divide``.

    Scalar operands are wrapped into scalar tensors, then the work is handed
    to ``_floor_divide_float`` or ``_floor_divide_integer`` depending on
    ``a``'s dtype.  Any other dtype fails a ``torch._check``.

    Raises:
        RuntimeError: if both operands are tensors on different devices and
            ``a`` lives on the CPU.
    """
    a_is_number = isinstance(a, Number)
    b_is_number = isinstance(b, Number)
    a_is_tensor = isinstance(a, Tensor)
    b_is_tensor = isinstance(b, Tensor)

    # Wrap scalars because some references only accept tensor arguments.
    if a_is_number and b_is_number:
        a = scalar_tensor(a)
        b = scalar_tensor(b)
    elif b_is_number and a_is_tensor:
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif a_is_number and b_is_tensor:
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif a_is_tensor and b_is_tensor and a.device != b.device:
        if a.device == torch.device("cpu"):
            raise RuntimeError(
                f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
            )
        # Otherwise the divisor is moved onto the dividend's device.
        b = prims.device_put(b, device=a.device)

    assert isinstance(a, Tensor) and isinstance(b, Tensor)
    dtype = a.dtype
    if utils.is_float_dtype(dtype):
        return _floor_divide_float(a, b)
    if utils.is_integer_dtype(dtype):
        return _floor_divide_integer(a, b)
    torch._check(False, lambda: f"{dtype} not supported for floor_divide")
 | |
| 
 | |
| 
 | |
def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
    """Floor division helper for integer tensors.

    Unsigned dtypes return ``prims.div(a, b)`` directly.  For signed dtypes
    the ``prims.div`` quotient is lowered by one wherever the operands differ
    in sign and ``fmod(a, b)`` is non-zero.
    """
    a, b = _maybe_broadcast(a, b)

    if not a.dtype.is_signed:
        return prims.div(a, b)

    # Convert truncation to flooring:
    signs_differ = torch.signbit(a) != torch.signbit(b)
    has_remainder = torch.fmod(a, b) != 0
    needs_adjustment = signs_differ.logical_and(has_remainder)
    truncated = prims.div(a, b)
    return truncated - _maybe_convert_to_dtype(needs_adjustment, a.dtype)
 | |
| 
 | |
| 
 | |
def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
    """Floor division helper for floating point tensors.

    Computes ``(a - fmod(a, b)) / b``, corrects it for operands of different
    sign, snaps it to an integer value, carries the sign of the true quotient
    onto a zero result and falls back to ``true_divide`` where ``b`` is 0.
    """
    rem = fmod(a, b)
    quotient = true_divide(sub(a, rem), b)

    # Ensure that the remainder has the same sign as denominator
    signs_differ = bitwise_xor(lt(a, 0), lt(b, 0))
    rem_is_nonzero = ne(rem, 0)
    needs_decrement = bitwise_and(rem_is_nonzero, signs_differ)
    quotient = where(needs_decrement, sub(quotient, 1), quotient)

    # Map quotient to nearest integer value
    floored = floor(quotient)
    round_up = gt(sub(quotient, floored), 0.5)
    floored = where(round_up, add(floored, 1), floored)

    true_quotient = true_divide(a, b)
    zero = scalar_tensor(0, dtype=true_quotient.dtype, device=true_quotient.device)

    # If quotient is zero, copy signbit from true_divide quotient
    floored = where(ne(quotient, 0), floored, copysign(zero, true_quotient))

    # If denominator is zero, then follow true_divide behavior
    return where(ne(b, 0), floored, true_quotient)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``fmax``: forwards ``(a, b)`` to ``prims.fmax``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.fmax(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``fmin``: forwards ``(a, b)`` to ``prims.fmin``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.fmin(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=True,
)
def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``fmod``: forwards ``(a, b)`` to ``prims.fmod``.

    Only the right operand may be a Python scalar, per the decorator arguments.
    """
    result = prims.fmod(a, b)
    return result
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
    """Reference for ``frexp``.

    Wraps the result of ``prims.frexp(self)`` in ``torch.return_types.frexp``.
    """
    parts = prims.frexp(self)
    return torch.return_types.frexp(parts)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``gcd``: forwards ``(a, b)`` to ``prims.gcd``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.gcd(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``ge``: forwards ``(a, b)`` to ``prims.ge``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.ge(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``gt``: forwards ``(a, b)`` to ``prims.gt``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.gt(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
    """Reference for ``heaviside``.

    Yields ``values`` where ``input == 0``, 0 where ``input`` is negative or
    NaN, and 1 everywhere else.
    """
    is_zero = torch.eq(input, 0)
    is_negative_or_nan = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
    step = torch.where(is_negative_or_nan, 0, 1)
    # Positions that are exactly zero take their result from ``values``.
    return torch.where(is_zero, values, step)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``hypot``: forwards ``(a, b)`` to ``prims.hypot``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.hypot(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``igamma``: forwards ``(a, b)`` to ``prims.igamma``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.igamma(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``igammac``: forwards ``(a, b)`` to ``prims.igammac``.

    Python scalars are rejected for both operands by the decorator arguments.
    """
    result = prims.igammac(a, b)
    return result
 | |
| 
 | |
| 
 | |
| def _check_close_args(
 | |
|     name: str,
 | |
|     a: TensorLikeType,
 | |
|     b: TensorLikeType,
 | |
|     rtol: float,
 | |
|     atol: float,
 | |
| ) -> None:
 | |
|     torch._check_value(
 | |
|         a.dtype == b.dtype,
 | |
|         lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
 | |
|     )
 | |
|     torch._check(
 | |
|         rtol >= 0,
 | |
|         lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
 | |
|     )
 | |
|     torch._check(
 | |
|         atol >= 0,
 | |
|         lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
 | |
|     )
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    """Reference for ``isclose``.

    Elements are close when they are equal, when both are NaN and
    ``equal_nan`` is set (float/complex dtypes only), or when
    ``abs(a - b)`` is finite and at most ``atol + abs(b * rtol)``.

    Args:
        a: first tensor.
        b: second tensor, same dtype as ``a``.
        rtol: relative tolerance (>= 0).
        atol: absolute tolerance (>= 0).
        equal_nan: whether two NaNs count as close.
    """
    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)

    is_inexact = utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)

    close = eq(a, b)
    if equal_nan and is_inexact:
        both_nan = logical_and(isnan(a), isnan(b))
        close = logical_or(close, both_nan)

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not is_inexact:
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    within_tolerance = logical_and(
        isfinite(actual_error), le(actual_error, allowed_error)
    )
    return logical_or(close, within_tolerance)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def lcm(a: TensorLikeType, b: TensorLikeType):
    """Reference for ``lcm``, computed as ``abs((a / gcd(a, b)) * b)``.

    ``int8`` and ``int16`` inputs are computed in ``int32`` and converted back
    to the original dtype at the end.
    """
    original_dtype = a.dtype
    # promoting to int32 to maintain 100% consistency with C++ and to
    # prevent overflow in case of int8 and int16
    needs_int32 = original_dtype in (torch.int8, torch.int16)
    if needs_int32:
        a = prims.convert_element_type(a, torch.int32)
        b = prims.convert_element_type(b, torch.int32)

    divisor = torch.gcd(a, b)
    # Avoid division by zero in case gcd(0, 0) == 0
    divisor = torch.where(divisor == 0, 1, divisor)
    result = torch.abs(prims.div(a, divisor) * b)

    if needs_int32:
        return prims.convert_element_type(result, original_dtype)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``le``: forwards ``(a, b)`` to ``prims.le``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.le(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``logaddexp``.

    The general result is ``max + log1p(exp(min - max))``, where max/min are
    chosen by comparing real parts.  Positions where the real part of ``a`` is
    non-finite and equal to that of ``b`` are handled separately, and complex
    inputs get additional NaN handling.
    """
    # Nb. this implementation does not distribute the gradients evenly when a == b
    a_is_larger = torch.real(a) >= torch.real(b)
    larger = torch.where(a_is_larger, a, b)
    smaller = torch.where(a_is_larger, b, a)
    a_not_finite = torch.logical_not(torch.isfinite(torch.real(a)))
    nonfinite_and_equal = torch.logical_and(
        a_not_finite, torch.real(a) == torch.real(b)
    )

    if not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)):
        stable_sum = larger + torch.log1p(torch.exp(smaller - larger))
        return torch.where(nonfinite_and_equal, a, stable_sum)

    # The remaining steps cover edge cases of complex inputs.
    smaller_real_negative = torch.real(smaller) < 0
    direct_sum = torch.log(torch.exp(smaller) + torch.exp(larger))
    inf_vals = torch.where(smaller_real_negative, smaller, direct_sum)
    stable_sum = larger + torch.log1p(torch.exp(smaller - larger))
    non_nan_vals = torch.where(nonfinite_and_equal, inf_vals, stable_sum)
    # the type for full_like does not include tensor yet
    smaller_is_nan = torch.isnan(smaller)
    return torch.where(smaller_is_nan, complex(float("nan"), float("nan")), non_nan_vals)  # type: ignore[call-overload]
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``logaddexp2``.

    Computes ``max + log1p(exp2(min - max)) / log(2)`` and returns ``a``
    itself wherever ``a`` is infinite and equal to ``b``.  Complex dtypes are
    rejected with ``torch._check``.
    """
    has_complex = utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)
    torch._check(
        not has_complex,
        lambda: "logaddexp2 doesn't support complex dtypes",
    )
    # Nb. this implementation does not distribute the gradients evenly when a == b
    a_is_larger = a >= b
    larger = torch.where(a_is_larger, a, b)
    smaller = torch.where(a_is_larger, b, a)
    equal_infinities = torch.logical_and(torch.isinf(a), a == b)
    # log1p yields a natural logarithm; dividing by ln(2) converts it to base 2.
    inv_log_2 = 1.0 / math.log(2)
    result = larger + torch.log1p(torch.exp2(smaller - larger)) * inv_log_2
    return torch.where(equal_infinities, a, result)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_and(a: TensorLikeType, b: TensorLikeType):
    """Reference for ``logical_and``.

    Non-boolean operands are first compared against 0 (``!= 0``), then the
    two operands are combined with ``&``.
    """
    lhs = a if utils.is_boolean_dtype(a.dtype) else a != 0
    rhs = b if utils.is_boolean_dtype(b.dtype) else b != 0
    return lhs & rhs
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def logical_not(a: TensorLikeType):
    """Reference for ``logical_not``.

    Boolean inputs are inverted with ``~``; other dtypes return ``a == 0``.
    """
    if utils.is_boolean_dtype(a.dtype):
        return ~a
    return a == 0
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_or(a: TensorLikeType, b: TensorLikeType):
    """Reference for ``logical_or``.

    Non-boolean operands are first compared against 0 (``!= 0``), then the
    two operands are combined with ``bitwise_or``.
    """
    lhs = a if utils.is_boolean_dtype(a.dtype) else a != 0
    rhs = b if utils.is_boolean_dtype(b.dtype) else b != 0
    return bitwise_or(lhs, rhs)
 | |
| 
 | |
| 
 | |
# TODO: skip unnecessary conversion of long to float
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_xor(a: TensorLikeType, b: TensorLikeType):
    """Reference for ``logical_xor``.

    Non-boolean operands are first compared against 0 (``!= 0``), then the
    two operands are combined with ``^``.
    """
    lhs = a if utils.is_boolean_dtype(a.dtype) else a != 0
    rhs = b if utils.is_boolean_dtype(b.dtype) else b != 0
    return lhs ^ rhs
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``lt``: forwards ``(a, b)`` to ``prims.lt``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.lt(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``maximum``: forwards ``(a, b)`` to ``prims.maximum``."""
    result = prims.maximum(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``minimum``: forwards ``(a, b)`` to ``prims.minimum``."""
    result = prims.minimum(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
)
def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``mul``: forwards ``(a, b)`` to ``prims.mul``.

    Two Python scalars are accepted, per ``supports_two_python_scalars=True``.
    """
    product = prims.mul(a, b)
    return product
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``ne``: forwards ``(a, b)`` to ``prims.ne``.

    A Python scalar is rejected as the left operand by the decorator arguments.
    """
    result = prims.ne(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``nextafter``: forwards ``(a, b)`` to ``prims.nextafter``.

    Uses the NO_OPMATH promotion kind; Python scalars are rejected for both
    operands by the decorator arguments.
    """
    result = prims.nextafter(a, b)
    return result
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """Reference for ``remainder``: forwards ``(a, b)`` to ``prims.remainder``."""
    result = prims.remainder(a, b)
    return result
 | |
| 
 | |
| 
 | |
# reverse sub
@register_decomposition(aten.rsub)
@out_wrapper()
def rsub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    alpha: NumberType = 1,
):
    """Reference for ``rsub``: returns ``torch.sub(b, a, alpha=alpha)``.

    Raises:
        ValueError: if ``a`` is a ``Number``.
    """
    if isinstance(a, Number):
        raise ValueError(
            "Received a Number for the first argument, but expected a Tensor"
        )

    # The operands are swapped relative to ``sub``.
    return torch.sub(b, a, alpha=alpha)
 | |
| 
 | |
| 
 | |
# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
@register_decomposition(aten.sub)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: NumberType = 1,
):
    """
    Reference implementation of torch.sub

    Broadcasts ``a`` and ``b``, scales ``b`` by ``alpha`` when ``alpha != 1``,
    and subtracts with ``prims.sub``.  Boolean tensors are rejected when both
    operands are tensor-like.

    Raises:
        ValueError: if ``type(alpha)`` is not weakly lesser than the Python
            type of the operand dtype.
    """

    a, b = _maybe_broadcast(a, b)

    if isinstance(a, TensorLike) and isinstance(b, TensorLike):
        any_boolean = utils.is_boolean_dtype(a.dtype) or utils.is_boolean_dtype(
            b.dtype
        )
        torch._check(
            not any_boolean,
            lambda: (
                "Subtraction, the `-` operator, with two bool tensors is not supported. "
                "Use the `^` or `logical_xor()` operator instead."
            ),
        )

    if alpha != 1:
        # The dtype is taken from ``a`` when it is tensor-like, else from ``b``.
        if isinstance(a, TensorLike):
            dtype = a.dtype
        else:
            dtype = b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            raise ValueError(
                f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            )
        # Carefully not to use prims.mul if b is a scalar / symint.
        # prims.mul always returns a tensor,
        # which will mess with type promotion.
        b = prims.mul(b, alpha) if isinstance(b, torch.Tensor) else b * alpha

    output = prims.sub(a, b)
    return handle_noncontiguous_outputs([a, b], output)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    name="true_divide",
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    """
    Reference implementation of torch.true_divide

    Forwards the (already promoted) operands to ``prims.div``.
    """
    quotient = prims.div(a, b)
    return quotient
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.xlogy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    """
    Reference implementation of torch.xlogy

    Computes ``a * log(b)`` elementwise, yielding 0 wherever ``a == 0`` and
    NaN wherever ``b`` is NaN (the NaN rule is applied last, so it wins).

    At least one of ``a`` and ``b`` must be a tensor; a scalar operand is
    turned into a scalar tensor with the dtype and device of the other one.
    """
    torch._check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        # The message previously ended with a stray double quote.
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, TensorLike) and isinstance(b, Number):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)
    # 0 * log(b) is defined to be 0 here, even where log(b) is inf or nan.
    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
    # A NaN in b always propagates, including where a == 0.
    return torch.where(torch.isnan(b), float("nan"), rhs)
 | |
| 
 | |
| 
 | |
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def trunc_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    """
    Division rounded toward zero.

    When ``a`` has an integer dtype the result of ``prims.div`` is returned
    as is; otherwise it is passed through ``trunc``.
    """
    is_integral = utils.is_integer_dtype(utils.get_dtype(a))
    quotient = prims.div(a, b)
    return quotient if is_integral else trunc(quotient)
 | |
| 
 | |
| 
 | |
| #
 | |
| # Elementwise Ternary References
 | |
| #
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.addcdiv)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def addcdiv(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcdiv

    Returns ``self + value * tensor1 / tensor2``. ``value`` must be safely
    castable to the Python type matching ``self.dtype``; otherwise
    ``torch._check_value`` reports the failure.
    """
    if value is not None:
        # no scalars allowed, see add
        python_type = utils.dtype_to_type(self.dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    scaled_numerator = value * tensor1
    return self + scaled_numerator / tensor2
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.addcmul)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addcmul(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcmul

    Returns ``self + value * tensor1 * tensor2``. ``value`` must be safely
    castable to the Python type matching ``self.dtype``; otherwise
    ``torch._check_value`` reports the failure.
    """
    if value is not None:
        # no scalars allowed, see add
        python_type = utils.dtype_to_type(self.dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    scaled_first = value * tensor1
    return self + scaled_first * tensor2
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.clamp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.clamp

    Applies the lower bound first and the upper bound second, each through
    ``torch.where``. NaN entries of ``a`` are kept as they are.

    Raises:
        ValueError: if both ``min`` and ``max`` are None.
    """
    # NOTE: grad behavior with implementation `where` is not consistent on `nan`
    if min is None and max is None:
        raise ValueError("clamp called but both min and max are none!")

    if min is not None:
        nan_mask = torch.isnan(a)
        # Keep `a` where it already satisfies the bound or is nan.
        # `nan` coming from the boundary needs no extra handling: `ge`
        # already yields `False` when either operand is `nan`, so
        #   `keep = bitwise_and(keep, bitwise_not(isnan(min)))`
        # would be redundant.
        keep = torch.bitwise_or(torch.ge(a, min), nan_mask)  # type: ignore[arg-type]
        a = torch.where(keep, a, min)  # type: ignore[arg-type]

    if max is not None:
        nan_mask = torch.isnan(a)
        # same as above, no need to adjust `nan` from `max`
        keep = torch.bitwise_or(torch.le(a, max), nan_mask)  # type: ignore[arg-type]
        a = torch.where(keep, a, max)  # type: ignore[arg-type]

    return a
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.clamp_min)
@out_wrapper()
def clamp_min(
    self: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.clamp_min

    Forwards to ``torch.clamp`` with only the lower bound set.
    """
    lower_bounded = torch.clamp(self, min=min)  # type: ignore[arg-type]
    return lower_bounded
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.clamp_max)
@out_wrapper()
def clamp_max(
    self: TensorLikeType,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.clamp_max

    Forwards to ``torch.clamp`` with only the upper bound set.
    """
    upper_bounded = torch.clamp(self, max=max)  # type: ignore[arg-type]
    return upper_bounded
 | |
| 
 | |
| 
 | |
| #
 | |
| # Conditional references
 | |
| #
 | |
| 
 | |
| 
 | |
| # https://pytorch.org/docs/stable/generated/torch.where.html
 | |
| # TODO: implement alternate where
 | |
@register_decomposition(aten.where)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def where(
    pred: Tensor,
    a: Optional[TensorOrNumberLikeType] = None,
    b: Optional[TensorOrNumberLikeType] = None,
):
    """
    Reference implementation of the three-argument torch.where

    ``pred`` must be a bool tensor. The three operands are broadcast
    together and handed to ``prims.where``.

    Raises:
        NotImplementedError: if ``a`` or ``b`` is None (the single-argument
            form is not implemented here).
    """

    if a is None or b is None:
        raise NotImplementedError

    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    torch._check(
        pred.dtype is torch.bool,
        lambda: f"expected predicate to be bool, got {pred.dtype}",
    )

    pred_b, a_b, b_b = _maybe_broadcast(pred, a, b)
    return prims.where(pred_b, a_b, b_b)
 | |
| 
 | |
| 
 | |
| #
 | |
| # Data Movement References
 | |
| #
 | |
@register_decomposition(aten.clone)
@out_wrapper()
def clone(
    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    """
    Reference implementation of torch.clone

    Forwards to ``prims.clone`` with the requested memory format.
    """
    return prims.clone(a, memory_format=memory_format)
 | |
| 
 | |
| 
 | |
def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
    """
    Copies ``b`` into ``a`` through ``prims.copy_to`` and returns its result.

    Raises:
        RuntimeError: if ``allow_cross_device`` is False and the two tensors
            live on different devices.
    """
    if not allow_cross_device:
        if a.device != b.device:
            raise RuntimeError(
                f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!"
            )

    return prims.copy_to(a, b)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.item)
def item(a: TensorLikeType) -> NumberType:
    """
    Reference implementation of torch.Tensor.item

    Returns the single element of ``a`` as a Python number of the type that
    corresponds to ``a.dtype``.

    Raises:
        ValueError: if ``a`` does not hold exactly one element.
    """
    if a.numel() != 1:
        raise ValueError(
            f"Can't convert a tensor with {a.numel()} elements to a number!"
        )

    # NOTE: explicit conversion is necessary for bool!
    # See https://github.com/pytorch/pytorch/issues/78071
    to_python_number = utils.dtype_to_type(a.dtype)
    return to_python_number(prims.item(a))
 | |
| 
 | |
| 
 | |
| # fast path when `to` returns an alias to input. This mimics the same function in aten
 | |
| def _to_will_alias(
 | |
|     a: TensorLikeType,
 | |
|     device: Optional[DeviceLikeType] = None,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     copy: Optional[bool] = None,
 | |
|     layout: Optional[torch.layout] = None,
 | |
|     memory_format: Optional[torch.memory_format] = None,
 | |
|     pin_memory: Optional[bool] = False,
 | |
|     non_blocking: bool = False,  # not using non_blocking
 | |
| ) -> bool:
 | |
|     return (
 | |
|         not copy
 | |
|         and (device is None or a.device == device)
 | |
|         and (dtype is None or a.dtype == dtype)
 | |
|         and (layout is None or a.layout == layout)
 | |
|         # is_pinned issue #84925
 | |
|         # and (pin_memory is None or pin_memory == a.is_pinned())
 | |
|         and (
 | |
|             memory_format is None
 | |
|             or memory_format == torch.preserve_format
 | |
|             or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
 | |
|         )
 | |
|     )
 | |
| 
 | |
| 
 | |
| @singledispatch
 | |
| def _to_dispatch(*args, **kwargs):
 | |
|     raise NotImplementedError
 | |
| 
 | |
| 
 | |
@_to_dispatch.register
def _to_device(
    device: torch.device,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    """
    Overload selected when the first positional argument is a
    ``torch.device``: repackages the positional arguments as keyword
    arguments.
    """
    return {
        "device": device,
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
 | |
| 
 | |
| 
 | |
@_to_dispatch.register
def _to_device_str(
    device: str,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    """
    Overload selected when the first positional argument is a device
    string: converts it to a ``torch.device`` and repackages the positional
    arguments as keyword arguments.
    """
    parsed_device = torch.device(device)
    return {
        "device": parsed_device,
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
 | |
| 
 | |
| 
 | |
@_to_dispatch.register
def _to_dtype(
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    """
    Overload selected when the first positional argument is a
    ``torch.dtype``: repackages the positional arguments as keyword
    arguments.
    """
    return {
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
 | |
| 
 | |
| 
 | |
@_to_dispatch.register
def _to_other(
    other: Tensor,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    """
    Overload selected when the first positional argument is a tensor: the
    target device, dtype and layout are taken from ``other``.
    """
    # is_pinned issue #84925
    # pin_memory = other.is_pinned()
    return {
        "device": other.device,
        "dtype": other.dtype,
        "layout": other.layout,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
 | |
| 
 | |
| 
 | |
| # remove to_kwargs that is already present in `a`
 | |
| def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
 | |
|     options_to_check = ["dtype", "device", "layout", "memory_format"]
 | |
|     # "device" option could be passed a str instead torch.device
 | |
|     if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
 | |
|         to_kwargs["device"] = torch.device(to_kwargs["device"])
 | |
| 
 | |
|     for kw in options_to_check:
 | |
|         if kw in to_kwargs:
 | |
|             if (
 | |
|                 (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
 | |
|                 or (
 | |
|                     kw == "device"
 | |
|                     and to_kwargs[kw].type == a.device.type
 | |
|                     and (
 | |
|                         not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
 | |
|                     )
 | |
|                 )
 | |
|                 or (
 | |
|                     getattr(a, kw, None) == to_kwargs[kw]
 | |
|                 )  # this also handles {"memory_format": None}
 | |
|             ):
 | |
|                 to_kwargs.pop(kw)
 | |
| 
 | |
| 
 | |
def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
    """
    Reference implementation of torch.Tensor.to

    Positional arguments are first normalized into keyword arguments. ``a``
    itself is returned when the request would alias it, a pure dtype change
    goes through ``prims.convert_element_type``, and everything else
    allocates a new tensor with ``torch.empty_like`` and copies ``a`` into
    it.
    """
    # positional arguments are dispatched on the type of the first one
    if args:
        kwargs = _to_dispatch(*args, **kwargs)

    # TODO: is_pinned is not currently supported in refs or fake_tensor
    # https://github.com/pytorch/pytorch/issues/84925
    assert "pin_memory" not in kwargs
    _canonicalize_to_arguments(a, kwargs)

    if _to_will_alias(a, **kwargs):
        return a

    copy = kwargs.pop("copy", False)
    non_blocking = kwargs.pop("non_blocking", False)

    # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
    target_dtype = kwargs.get("dtype", a.dtype)
    # is_pinned issue #84925: "pin_memory" would have to be listed here too
    only_dtype_requested = not (
        non_blocking
        or "memory_format" in kwargs
        or "device" in kwargs
        or "layout" in kwargs
    )
    if (copy or target_dtype != a.dtype) and only_dtype_requested:
        return prims.convert_element_type(a, target_dtype)

    converted = torch.empty_like(a, **kwargs)
    # TODO: non_blocking should be handled by `copy_to`
    copy_to(converted, a)
    return converted
 | |
| 
 | |
| 
 | |
| #
 | |
| # Reduction references
 | |
| #
 | |
| 
 | |
| 
 | |
def _reduction(
    a: TensorLikeType,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
) -> TensorLikeType:  # it is usually SAME, but I want
    # ref writers to actually think about what to put here
    """
    Shared driver for the reduction references.

    Validates the arguments, canonicalizes ``dims``, converts ``a`` to the
    computation dtype, applies ``prim(a, dims)``, optionally restores the
    reduced dimensions as size-1 dims, and finally either copies the result
    into ``out`` or converts it to the result dtype.

    Raises:
        RuntimeError: if ``a`` has more than 64 dims, if ``dtype`` and
            ``out.dtype`` disagree, or if a reduction without identity is
            asked to reduce over a zero-size dimension.
    """
    assert isinstance(a, TensorLike)
    if a.ndim > 64:
        raise RuntimeError(
            f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!"
        )

    if out is not None:
        assert isinstance(out, TensorLike)
        # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
        if dtype is not None and dtype != out.dtype:
            raise RuntimeError("dtype argument and out dtype must match in reduction")

    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, Dim)
    if isinstance(dims, Dim):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)

    if not has_identity:
        # without an identity element there is nothing to return for an
        # empty reduction, so every reduced dimension must be non-empty
        if not (a.ndim == 0 or builtins.all(a.shape[i] for i in dims)):
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )

    computation_dtype, result_dtype = utils.reduction_dtypes(
        a, output_dtype_kind, dtype
    )
    a = _maybe_convert_to_dtype(a, computation_dtype)  # type: ignore[method-assign]
    result = prim(a, dims)

    if keepdims:
        all_axes = range(a.ndim)
        output_shape = [1 if i in dims else a.shape[i] for i in all_axes]
        broadcast_dims = [i for i in all_axes if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)

    if out is not None:
        assert result_dtype is not None
        if dtype is not None and result_dtype != out.dtype:
            raise RuntimeError("Expected the dtype of reduction result and out to match")
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]

    if result_dtype is not None and result.dtype != result_dtype:
        return prims.convert_element_type(result, result_dtype)

    return result
 | |
| 
 | |
| 
 | |
def _make_copy_from_view(fn, return_none_on_out_variant=False):
    """
    Given a view function (e.g. torch.diagonal) generates its copy variant (e.g. torch.diagonal_copy)

    The generated function calls the aten op named after ``fn`` (wrapped so
    that it accepts ``out=``), clones every tensor of the result into
    contiguous memory when no ``out`` is given, and is registered as the
    decomposition of ``aten.<name>_copy``. With ``out`` given, the wrapped
    op's result is returned, or None when ``return_none_on_out_variant`` is
    set.
    """
    aten_fn = getattr(aten, fn.__name__)
    view_annotations = getattr(fn, "__annotations__", {})
    # view ops should not change dtypes, this ensures that the decomp path has
    # the same error checks as eager.
    view_with_out = out_wrapper(exact_dtype=True)(aten_fn)

    @wraps(view_with_out)
    def _fn(*args, out=None, **kwargs):
        result = view_with_out(*args, out=out, **kwargs)
        if out is None:
            return pytree.tree_map(
                lambda x: x.clone(memory_format=torch.contiguous_format),
                result,
            )
        return None if return_none_on_out_variant else result

    copy_name = f"{view_with_out.__name__}_copy"
    _fn.__name__ = copy_name
    _fn.__annotations__.update(view_annotations)
    register_decomposition(getattr(aten, copy_name))(_fn)
    return _fn
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.all)
@out_wrapper()
def all(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.all

    Expressed through ``torch.any``: all elements are true exactly when no
    element of the negated input is true. A uint8 input yields a uint8
    result.
    """
    negated_input = torch.logical_not(a)
    any_false = torch.any(negated_input, dim, keepdim=keepdim)
    result = torch.logical_not(any_false)

    if a.dtype == torch.uint8:
        return result.to(dtype=torch.uint8)

    return result
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.any)
@out_wrapper()
def any(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.any

    Converts the input to bool, sums it over ``dim`` and compares the sum
    with False. An empty list/tuple ``dim`` just clones the bool tensor. A
    uint8 input yields a uint8 result.
    """
    as_bool = _maybe_convert_to_dtype(a, torch.bool)
    reduces_nothing = isinstance(dim, (list, tuple)) and not dim
    if reduces_nothing:
        result = as_bool.clone()
    else:
        result = as_bool.sum(dim=dim, keepdim=keepdim).ne(False)

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is not torch.uint8:
        return result

    return prims.convert_element_type(result, torch.uint8)
 | |
| 
 | |
| 
 | |
@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
def sum(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.sum

    When ``dtype`` is not given it defaults to ``out.dtype`` if ``out`` is
    passed, to int64 for boolean and integer inputs, and to ``a.dtype``
    otherwise.
    """
    if dtype is None and out is not None:
        dtype = out.dtype
    elif dtype is None:
        is_integral = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(
            a.dtype
        )
        dtype = torch.int64 if is_integral else a.dtype

    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
 | |
| 
 | |
| 
 | |
| def sum_to_size(
 | |
|     a: Tensor,
 | |
|     *shape,
 | |
| ) -> Tensor:
 | |
|     shape = utils.extract_shape_from_varargs(shape, validate=False)
 | |
|     torch._check(
 | |
|         utils.is_expandable_to(shape, a.shape),
 | |
|         lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
 | |
|     )
 | |
|     # In ATen scalar tensors are sent through sum and the result is returned as
 | |
|     # type promoted
 | |
|     if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
 | |
|         return prims.view_of(a)
 | |
|     leading_dims = a.ndim - len(shape)
 | |
|     reduce_dims = tuple(range(leading_dims)) + tuple(
 | |
|         i
 | |
|         for i in range(leading_dims, len(shape))
 | |
|         if shape[i - leading_dims] == 1 and a.shape[i] != 1
 | |
|     )
 | |
|     return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.prod)
def prod(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.prod

    When ``dtype`` is not given it defaults to ``out.dtype`` if ``out`` is
    passed, to int64 for boolean and integer inputs, and to ``a.dtype``
    otherwise.
    """
    if dtype is None and out is not None:
        dtype = out.dtype
    elif dtype is None:
        is_integral = utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(
            a.dtype
        )
        dtype = torch.int64 if is_integral else a.dtype

    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.prod,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.amin)
def amin(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.amin

    Runs ``prims.amin`` through the shared reduction driver as a reduction
    without identity.
    """
    # reduces over all dimensions if dim=() is passed
    reduce_over = None if (dim == () or dim == []) else dim

    return _reduction(
        a,
        prims.amin,
        has_identity=False,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=None,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.amax)
def amax(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.amax

    Runs ``prims.amax`` through the shared reduction driver as a reduction
    without identity.
    """
    # reduces over all dimensions if dim=() is passed
    reduce_over = None if (dim == () or dim == []) else dim

    return _reduction(
        a,
        prims.amax,
        has_identity=False,
        dims=reduce_over,
        keepdims=keepdim,
        dtype=None,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )
 | |
| 
 | |
| 
 | |
| def _dim_var_dispatch(dim=None, unbiased=None):
 | |
|     # There's the following overload of torch.var:
 | |
|     # var(Tensor self, bool unbiased=True) -> (Tensor, Tensor)
 | |
|     # We need to explicitly convert bool dims to unbiased arg
 | |
|     if unbiased is None and isinstance(dim, bool):
 | |
|         unbiased = dim
 | |
|         dim = None
 | |
|     return dim, unbiased
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.var)
@out_wrapper()
def var(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.var

    Resolves ``unbiased`` / ``correction`` into one correction value and
    runs ``prims.var`` with it through the shared reduction driver.
    """
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    var_prim = partial(prims.var, correction=correction)
    return _reduction(
        a,
        var_prim,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.std)
@out_wrapper()
def std(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
) -> TensorLikeType:
    """
    Reference implementation of torch.std

    Computed as the square root of ``torch.var`` on the input converted to
    its computation dtype; the result is converted to the result dtype
    afterwards.
    """
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)

    opmath_dtype, result_dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    upcast = _maybe_convert_to_dtype(a, opmath_dtype)
    variance = torch.var(upcast, dim, correction=correction, keepdim=keepdim)
    deviation = torch.sqrt(variance)
    assert result_dtype is not None
    return _maybe_convert_to_dtype(deviation, result_dtype)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.mean)
def mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out=None,
) -> TensorLikeType:
    """
    Reference implementation of torch.mean

    Sums over the requested dimensions and divides by the number of reduced
    elements. The dtype (``dtype`` if given, else ``a.dtype``) must be a
    floating point or complex dtype. When ``out`` is given the result is
    copied into it.
    """
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    orig_dtype = dtype
    dtype = a.dtype if orig_dtype is None else orig_dtype

    total = _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=None,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
    )
    torch._check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: (
            f"mean(): could not infer output dtype. "
            f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
            f"a floating point or complex dtype. Got: {dtype}"
        ),
    )

    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
    if a.ndim == 0:
        nelem = 1
    else:
        nelem = reduce(operator.mul, (a.shape[i] for i in dims), 1)

    averaged = true_divide(total, nelem)
    # dtype was defaulted to a.dtype above, so it is never None here
    averaged = _maybe_convert_to_dtype(averaged, dtype)  # type: ignore[method-assign]
    if out is None:
        return averaged

    assert isinstance(out, TensorLike)
    out = _maybe_resize_out(out, averaged.shape)
    return _safe_copy_out(copy_from=averaged, copy_to=out)  # type: ignore[arg-type]
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.std_mean)
@out_wrapper("out0", "out1")
def std_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    *,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    correction: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.std_mean

    Returns ``(std, mean)``. Both come from one ``torch.var_mean`` call on
    the input converted to its computation dtype; the std is converted to
    the result dtype and the mean back to the input's original dtype.
    """
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    opmath_dtype, std_dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    input_dtype = a.dtype
    upcast = _maybe_convert_to_dtype(a, opmath_dtype)
    variance, average = torch.var_mean(
        upcast, dim, correction=correction, keepdim=keepdim
    )
    deviation = torch.sqrt(variance)
    assert std_dtype is not None
    std_out = _maybe_convert_to_dtype(deviation, std_dtype)
    mean_out = _maybe_convert_to_dtype(average, input_dtype)
    return (std_out, mean_out)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.var_mean)
@out_wrapper("out0", "out1")
def var_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.var_mean

    Returns ``(var, mean)``, each computed by the corresponding reference
    in this module.
    """
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    variance = var(a, dim, unbiased, keepdim, correction=correction)
    average = mean(a, dim, keepdim)
    return variance, average
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.addr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "vec1", "vec2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addr(
    self: TensorLikeType,
    vec1: TensorLikeType,
    vec2: TensorLikeType,
    *,
    beta: NumberType = 1,
    alpha: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addr

    Computes ``beta * self + alpha * outer(vec1, vec2)`` with ``self``
    expanded to ``(vec1.shape[0], vec2.shape[0])``. For boolean tensors the
    sum becomes a logical or and the scalars act as on/off switches. When
    ``beta`` is zero, ``self`` does not contribute at all, so NaNs in it
    are dropped.

    ``vec1`` and ``vec2`` must be 1-D. Boolean ``alpha`` / ``beta`` are only
    accepted when all three tensors are boolean, and both scalars must be
    safely convertible to the dtype of ``self``; violations are reported
    through ``torch._check``.
    """
    torch._check(
        vec1.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
    )
    torch._check(
        vec2.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
    )
    for arg, arg_name in ((alpha, "alpha"), (beta, "beta")):
        if isinstance(arg, bool):
            torch._check(
                utils.is_boolean_dtype(self.dtype)
                and utils.is_boolean_dtype(vec1.dtype)
                and utils.is_boolean_dtype(vec2.dtype),
                lambda: f"Boolean {arg_name} only supported for Boolean results.",
            )
    self = self.expand(vec1.shape[0], vec2.shape[0])
    if utils.is_boolean_dtype(self.dtype):
        # Integers are accepted for booleans
        torch._check(
            is_weakly_lesser_type(type(beta), int),
            lambda: f"expected bool/int beta but got {type(beta)}",
        )
        torch._check(
            is_weakly_lesser_type(type(alpha), int),
            # report the type of alpha here (this used to print type(beta))
            lambda: f"expected bool/int alpha but got {type(alpha)}",
        )
        # a falsy alpha switches the outer product off entirely
        outer_term = torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
        if not beta:
            return outer_term
        return torch.logical_or(self, outer_term)
    else:
        torch._check(
            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
        )
        torch._check(
            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
        )
        if beta == 0:
            # This means NaNs from self are dropped if beta is zero
            return alpha * torch.outer(vec1, vec2)
        else:
            return beta * self + alpha * torch.outer(vec1, vec2)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def atleast_1d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_1d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments.  0-D inputs gain a leading dimension; inputs that already have
    at least one dimension are passed through unchanged.  A tuple is returned
    when more than one tensor results, otherwise the single tensor itself.
    """
    if args or not isinstance(arg, collections.abc.Sequence):
        # Varargs form: the first argument must be a tensor, not a sequence.
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg, *args)
    else:
        tensors = arg
    promoted = tuple(t if t.ndim >= 1 else unsqueeze(t, 0) for t in tensors)
    if len(promoted) > 1:
        return promoted
    return promoted[0]
 | |
| 
 | |
| 
 | |
| # Helper function with assert to avoid MyPy error
 | |
| # of incompatible type passed to unsqueeze
 | |
| def _unsqueeze_atleast(
 | |
|     at_least_fn: Callable, dim: int, arg: TensorLikeType
 | |
| ) -> TensorLikeType:
 | |
|     arg_ = at_least_fn(arg)
 | |
|     assert isinstance(arg_, TensorLike)
 | |
|     return unsqueeze(arg_, dim)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def atleast_2d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_2d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments.  Inputs with fewer than two dimensions are first made at least
    1-D and then given a new leading dimension; other inputs are passed
    through unchanged.  A tuple is returned when more than one tensor
    results, otherwise the single tensor itself.
    """
    if args or not isinstance(arg, collections.abc.Sequence):
        # Varargs form: the first argument must be a tensor, not a sequence.
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg, *args)
    else:
        tensors = arg
    add_leading_dim = partial(_unsqueeze_atleast, atleast_1d, 0)
    promoted = tuple(t if t.ndim >= 2 else add_leading_dim(t) for t in tensors)
    if len(promoted) > 1:
        return promoted
    return promoted[0]
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def atleast_3d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_3d`.

    Accepts either one sequence of tensors or tensors as separate positional
    arguments.  Inputs with fewer than three dimensions are first made at
    least 2-D and then given a new trailing dimension; other inputs are
    passed through unchanged.  A tuple is returned when more than one tensor
    results, otherwise the single tensor itself.
    """
    if args or not isinstance(arg, collections.abc.Sequence):
        # Varargs form: the first argument must be a tensor, not a sequence.
        assert not isinstance(arg, collections.abc.Sequence)
        tensors = (arg, *args)
    else:
        tensors = arg
    add_trailing_dim = partial(_unsqueeze_atleast, atleast_2d, -1)
    promoted = tuple(t if t.ndim >= 3 else add_trailing_dim(t) for t in tensors)
    if len(promoted) > 1:
        return promoted
    return promoted[0]
 | |
| 
 | |
| 
 | |
def as_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    """Return ``prims.as_strided`` of ``a`` with the given size and stride.

    When ``storage_offset`` is omitted, the storage offset of ``a`` itself is
    used.
    """
    offset = a.storage_offset() if storage_offset is None else storage_offset
    return prims.as_strided(a, size, stride, offset)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.as_strided_scatter)
@out_wrapper()
def as_strided_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    """Forward to ``prims.as_strided_scatter``.

    An omitted ``storage_offset`` is treated as ``0``.
    """
    offset = storage_offset
    if offset is None:
        offset = 0
    return prims.as_strided_scatter(input, src, size, stride, offset)
 | |
| 
 | |
| 
 | |
def broadcast_shapes(*shapes) -> ShapeType:
    """Return the result of ``_broadcast_shapes(*shapes)`` as a ``torch.Size``."""
    common_shape = _broadcast_shapes(*shapes)
    return torch.Size(common_shape)
 | |
| 
 | |
| 
 | |
@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> list[TensorLikeType]:
    """Broadcast the given tensors against each other and return them as a list.

    A single non-tensor argument is taken to be a collection of tensors and
    is unpacked before broadcasting.
    """
    passed_as_collection = len(tensors) == 1 and not isinstance(tensors[0], Tensor)
    operands = tensors[0] if passed_as_collection else tensors
    return list(_maybe_broadcast(*operands, preserve_cpu_scalar_tensors=False))
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
    """Broadcast ``a`` to ``size`` through ``prims.broadcast_in_dim``.

    The dimensions of ``a`` are mapped onto the trailing ``len(a.shape)``
    positions of ``size``.
    """
    lead = len(size) - len(a.shape)
    target_dims = tuple(lead + i for i in range(len(a.shape)))
    return prims.broadcast_in_dim(a, size, target_dims)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.cat)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("tensors",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of :func:`torch.cat`.

    Concatenates ``tensors`` along ``dim``.  1-D tensors known to be empty
    are dropped before concatenation (legacy behavior), so they may be mixed
    with inputs of a different rank.  If every input is dropped, an empty 1-D
    tensor with the dtype and device of the first input is returned.

    Args:
        tensors: non-empty sequence of tensors on the same device.
        dim: dimension to concatenate along; canonicalized against the rank
            of the first tensor that survives filtering.

    Raises:
        ValueError: if ``tensors`` is empty.
    """

    def cat_compute_output_memory_format(inputs):
        # Returns the memory format suggested by all inputs when they agree;
        # falls back to contiguous_format as soon as one input suggests it or
        # two inputs disagree.
        format = None
        for t in inputs:
            f = utils.suggest_memory_format(t)
            if f == torch.contiguous_format:
                return f
            if format is not None and format != f:
                return torch.contiguous_format
            format = f
        assert format is not None
        return format

    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_size_oblivious,
    )

    # This is a bit tricky.  Naively, you would expect to just pick one
    # arbitrary tensor and check that all tensors match this tensor.  However,
    # there is legacy behavior which says that if you have a 1-D empty tensor
    # (0,), this is permissible.  So you can't assume that all the tensors
    # have same dimensionality, and you can't assume that the first tensor is
    # the correct stencil.
    #
    # We'll implement this in a few passes.  First, we will try to infer the
    # ndim of the cat output.  If this ndim != 1, then we know that all ndim =
    # 1 inputs must be empty, or are errors.  If this ndim == 1, then life
    # is easy (the legacy special case coincides with regular handling).
    #
    # NB: The regular implementation of cat just filters out empty inputs,
    # but we do it slightly different here for better handling for unbacked
    # SymInts

    # Pass 1: the first non-1-D tensor (if any) fixes the expected rank; every
    # later non-1-D tensor must have that same rank.
    example = None
    for i, t in enumerate(tensors):
        if example is None:
            if t.ndim != 1:
                example = t
        else:
            if t.ndim != 1:
                torch._check(
                    t.ndim == example.ndim,
                    lambda: "Number of dimensions of tensors must match.  "
                    f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for "
                    f"tensor number {i} in the list",
                )

    if example is None:
        # example is None if everything is 1-D.  If so, just arbitrarily pick
        # the first one
        example = tensors[0]

    # Pass 2: collect the tensors that actually take part in the concatenation.
    shape = example.shape
    filtered = []
    for tensor_idx, tensor in enumerate(tensors):
        if len(shape) != len(tensor.shape):
            assert tensor.ndim == 1  # we've already checked this above
            # Don't suggest the legacy behavior in the error message
            torch._check(
                # NB: it is not enough to simply assert that tensor.shape[0] == 0;
                # this MUST be true even under guard size oblivious.
                # Effectively, we must actually know that the shape is zero,
                # passing an unbacked SymInt which we will defer a runtime
                # assert on won't cut it.  This is a policy decision (size
                # oblivious semantics say that u0 tensors never are inferred
                # to be zero size, even if they must be that for the cat to go
                # through), and is load bearing for our Inductor lowerings
                # (which assume that size oblivious tests are OK to determine
                # if a shape is permissibly zero.)
                guard_size_oblivious(tensor.shape[0] == 0),
                lambda: f"Number of dimensions of tensors must match.  "
                f"Expected {example.ndim}-D tensors, but got 1-D for "
                f"tensor number {tensor_idx} in the list",
            )
        else:
            # Remove inputs that are 1-D, zero size
            if tensor.ndim == 1 and guard_or_false(tensor.shape[0] == 0):
                continue
            # Don't bother checking size match, prims.cat will handle it
            filtered.append(tensor)

    # The output memory format is derived from all inputs, including the ones
    # that were filtered out above.
    memory_format = cat_compute_output_memory_format(tensors)

    if len(filtered) == 0:
        t = tensors[0]

        # TODO: fix this to work with meta tensors
        try:
            # BUG? This looks like it wants to call builtins.any() but is
            # actually calling .any() (in this file). Changing to builtins.any()
            # causes tests to fail:
            # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \
            #   TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32
            requires_grad = bool(any(x.requires_grad for x in tensors))  # type: ignore[arg-type]
        except Exception:
            requires_grad = False  # type: ignore[assignment]

        return empty(
            (0,),
            dtype=t.dtype,
            device=t.device,
            requires_grad=requires_grad,
            memory_format=memory_format,
        )

    dim = utils.canonicalize_dim(filtered[0].ndim, dim)
    utils.validate_idx(filtered[0].ndim, dim)

    return prims.cat(filtered, dim).clone(memory_format=memory_format)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
@out_wrapper()
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    """Concatenate ``tensors`` along dimension 1.

    Inputs with at most one dimension are first reshaped to a single column
    of shape ``(numel, 1)``; inputs with more dimensions are used as they are.
    """

    def as_columns(t):
        if t.ndim > 1:
            return t
        return t.reshape((t.numel(), 1))

    return cat(tuple(as_columns(t) for t in tensors), 1)
 | |
| 
 | |
| 
 | |
def conj(input: TensorLikeType) -> TensorLikeType:
    """Return the complex conjugate of ``input``.

    Tensors with a non-complex dtype are returned as-is.  Sparse complex
    tensors go through ``torch.conj_physical``; all other complex tensors go
    through ``prims.conj``.
    """
    if utils.is_complex_dtype(input.dtype):
        if input.is_sparse:
            return torch.conj_physical(input)
        return prims.conj(input)
    return input
 | |
| 
 | |
| 
 | |
| # This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
 | |
@register_decomposition(aten.constant_pad_nd)
@out_wrapper()
def constant_pad_nd(
    input: TensorLikeType, pad: list[int], value: NumberType = 0
) -> TensorLikeType:
    """Pad (or crop) the trailing dimensions of ``input`` with a constant.

    Args:
        input: tensor to pad.
        pad: flat list of ``(before, after)`` amounts, starting with the last
            dimension of ``input`` and moving towards the first.  Its length
            must be even and at most twice ``input.ndim``.  Negative amounts
            crop the corresponding side instead of padding it.
        value: fill value for the padded region.

    Returns:
        A new tensor holding ``input`` (after any cropping) surrounded by
        ``value``.  A dimension may legitimately end up with size zero; only
        a negative resulting size is rejected.
    """
    torch._check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )

    input_sizes = input.shape
    l_inp = len(input_sizes)

    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad

    torch._check(
        l_inp >= l_pad,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{l_inp} dimensions.",
    )

    # Apply negative pads first by cropping the input.
    c_input = input
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] < 0:
            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])

        if pad[pad_idx + 1] < 0:
            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])

    # If all the pads are negative we can return the result.
    # Avoid early exiting if all pads = 0 to prevent specialization on export.
    # During export, raw if statements are specialized on the input, meaning
    # that we lose a branch depending on the example input used to export.
    # Here, this is either the case where all pads = 0, or the case where at
    # least one pad > 0 and the rest are >= 0.
    # Avoiding the early exit when all pads = 0 ensures we can export
    # constant_pad_nd for cases when all pads >= 0.
    # Note: if any pads are negative, this code specializes due to the if statements above.
    if builtins.all(p < 0 for p in pad):
        return c_input.clone()

    new_shape = list(input_sizes[:l_diff])

    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        # A zero-size output dimension is valid (e.g. padding an empty
        # dimension by zero, or cropping a dimension away entirely); only a
        # negative size is an error, as the message states.
        torch._check(
            new_dim >= 0,
            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
            f"which is invalid. Check dimension {l_diff + i} of your input.",
        )
        new_shape.append(new_dim)

    memory_format = utils.suggest_memory_format(input)
    output = torch.empty(
        new_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )

    if value == 0 and input.dtype == torch.bool:
        value = False
    # torch.fill isn't typed to allow complex values
    output = torch.fill(output, value)  # type: ignore[arg-type]

    # Select the interior of the output (skipping the non-negative pads) and
    # copy the (possibly cropped) input into it.
    c_output = output
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] >= 0:
            c_output = c_output.narrow(
                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
            )
        if pad[pad_idx + 1] >= 0:
            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])

    prims.copy_to(c_output, c_input)
    return output
 | |
| 
 | |
| 
 | |
def contiguous(
    a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor:
    """Return ``a`` laid out according to ``memory_format``.

    ``a`` itself is returned when ``contiguous_for_memory_format_or_false``
    reports that it already satisfies the format; otherwise a clone in the
    requested format is returned.  ``torch.preserve_format`` is rejected.
    """
    torch._check(
        memory_format != torch.preserve_format,
        lambda: "preserve memory format is unsupported by the contiguous operator",
    )

    # TODO: make logic consistent with aten contiguous
    already_in_format = contiguous_for_memory_format_or_false(
        a, memory_format=memory_format
    )
    return a if already_in_format else torch.clone(a, memory_format=memory_format)
 | |
| 
 | |
| 
 | |
| @out_wrapper()
 | |
| def dstack(tensors: TensorSequenceType) -> TensorLikeType:
 | |
|     torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
 | |
|     aligned_tensors = atleast_3d(*tensors)
 | |
|     return cat(aligned_tensors, 2)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.expand)
def expand(a: Tensor, *shape) -> Tensor:
    """Expand ``a`` to ``shape`` through ``prims.broadcast_in_dim``.

    ``shape`` may be given as separate integers or as one sequence.  It must
    have at least as many entries as ``a`` has dimensions; the dimensions of
    ``a`` line up with the trailing entries of ``shape``.  A requested length
    of ``-1`` keeps the size of the matching input dimension.
    """
    from torch.fx.experimental.symbolic_shapes import guard_or_false, sym_or

    # NOTE: utils.extract_shape_from_varargs cannot be used here because it
    # also validates the shape, and the shape given to expand may be
    # "invalid".
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = tuple(shape[0])

    rank = len(a.shape)
    torch._check(
        len(shape) >= rank,
        lambda: "expand: the requested shape has too few dimensions!",
    )

    offset = len(shape) - rank
    resolved = list(shape)
    for offset_idx, x in enumerate(a.shape, start=offset):
        requested_length = shape[offset_idx]

        # expand(in -> out) has three possible meanings:
        #   1) out == -1 -> size = in, stride unchanged
        #   2) in == 1   -> size = out, stride = 0
        #   3) in == out -> size = in, stride unchanged
        #
        # This is written for unbacked semantics: an unbacked symbol is
        # assumed not to be -1 unless that is known, so the caller is opting
        # into case 2) or 3).  sym_or admits either one, but in the
        # decomposition's current state broadcast_in_dim() will either assume
        # case 3) (validate_shape() marks the expanded shape size-like) or
        # raise a data-dependent error while working out whether the stride
        # is 0, in which case the user has to pick between 2) and 3) by hand.
        if guard_or_false(requested_length == -1):
            resolved[offset_idx] = x
            continue

        torch._check(
            sym_or(x == 1, requested_length == x),
            lambda: f"expand: attempting to expand a dimension of length {x} -> {requested_length}!",
        )
        torch._check(requested_length >= 0)
        resolved[offset_idx] = requested_length

    # By now every entry of the resolved shape has to be valid.
    utils.validate_shape(resolved)

    return prims.broadcast_in_dim(a, resolved, tuple(range(offset, len(shape))))
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def expand_as(a: Tensor, b: Tensor) -> Tensor:
    """Expand ``a`` to the shape of ``b`` by calling ``a.expand``."""
    target_shape = b.shape
    return a.expand(target_shape)
 | |
| 
 | |
| 
 | |
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]:
    """Split ``a`` along ``dim`` into views of (nearly) equal length.

    Each piece has length ``ceil(a.shape[dim] / chunks)`` except possibly the
    last one, which holds the remainder.  Fewer than ``chunks`` pieces are
    returned when the dimension is too short to fill them all.

    Args:
        a: tensor to split.
        chunks: requested number of pieces; must be positive.
        dim: dimension to split along.

    Returns:
        Tuple of views produced by ``narrow``.  When ``a.shape[dim]`` is zero,
        exactly ``chunks`` zero-length views are returned.

    Raises:
        ValueError: if ``chunks`` is not positive.
    """
    if chunks <= 0:
        msg = f"Expected at least one chunk, but got {chunks}!"
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]
    chunk_size = math.ceil(length / chunks)

    if chunk_size == 0:
        # Only reachable when the dimension is empty.  Dividing by the chunk
        # size below would raise ZeroDivisionError, so hand back the
        # requested number of empty pieces instead.
        return tuple(narrow(a, dim, 0, 0) for _ in range(chunks))

    full_chunks = math.floor(length / chunk_size)
    tail_chunk_size = length % chunk_size

    result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]

    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
 | |
| 
 | |
| 
 | |
| # Note: flatten, unlike other shape operators, returns the input tensor on a no-op (unless
 | |
| # a 0D tensor is flattened, in which case it's returned in 1D)
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
 | |
|     start_dim = utils.canonicalize_dim(a.ndim, start_dim)
 | |
|     end_dim = utils.canonicalize_dim(a.ndim, end_dim)
 | |
| 
 | |
|     # Short-circuits on no-op
 | |
|     if start_dim == end_dim and a.ndim != 0:
 | |
|         return a
 | |
| 
 | |
|     # Tries to take a view
 | |
|     # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
 | |
|     new_shape, _new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
 | |
|     if new_shape is not None:
 | |
|         return prims.collapse_view(a, start_dim, end_dim)
 | |
| 
 | |
|     # Makes a copy if it can't make a view
 | |
|     return prims.collapse(a, start_dim, end_dim)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.flip)
@out_wrapper()
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    """Reverse ``a`` along each dimension listed in ``dims``.

    ``dims`` must be a tuple or a list; it is canonicalized against
    ``a.ndim`` and checked for repeated entries before ``prims.rev`` is
    called.

    Raises:
        ValueError: if ``dims`` is neither a tuple nor a list.
    """
    if not isinstance(dims, (tuple, list)):
        raise ValueError("dims has to be a sequence of ints")
    canonical_dims = utils.canonicalize_dims(a.ndim, dims)
    utils.validate_no_repeating_dims(canonical_dims)
    return prims.rev(a, canonical_dims)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def fliplr(a: TensorLikeType) -> TensorLikeType:
    """Flip ``a`` along dimension 1.

    Raises:
        RuntimeError: if ``a`` has fewer than two dimensions.
    """
    if a.ndim >= 2:
        return flip(a, (1,))
    raise RuntimeError("Input must be >= 2-d.")
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def flipud(a: TensorLikeType) -> TensorLikeType:
    """Flip ``a`` along dimension 0.

    Raises:
        RuntimeError: if ``a`` is 0-dimensional.
    """
    if a.ndim >= 1:
        return flip(a, (0,))
    raise RuntimeError("Input must be >= 1-d.")
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| def narrow(
 | |
|     a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
 | |
| ) -> TensorLikeType:
 | |
|     # Supports Tensor overload that was added for XLA:
 | |
|     # https://github.com/pytorch/pytorch/issues/31558
 | |
|     if isinstance(start, TensorLike):
 | |
|         torch._check(
 | |
|             start.dim() == 0 and utils.is_integer_dtype(start.dtype),
 | |
|             lambda: "start must be an 0-dim integral Tensor.",
 | |
|         )
 | |
|         start = start.item()  # type: ignore[assignment]
 | |
|     start = cast(int, start)
 | |
|     torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
 | |
|     torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
 | |
|     dim = utils.canonicalize_dim(a.ndim, dim)
 | |
|     dim_length = a.size(dim)
 | |
|     torch._check_with(
 | |
|         IndexError,
 | |
|         -dim_length <= start and start <= dim_length,
 | |
|         lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
 | |
|     )
 | |
|     if start < 0:
 | |
|         start = start + dim_length
 | |
|     torch._check(
 | |
|         start <= dim_length - length,
 | |
|         lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
 | |
|     )
 | |
|     new_shape = list(a.shape)
 | |
|     new_shape[dim] = length
 | |
|     return a.as_strided(
 | |
|         new_shape, a.stride(), a.storage_offset() + a.stride(dim) * start
 | |
|     )
 | |
| 
 | |
| 
 | |
| def _normalize(
 | |
|     a: Tensor, norm_dims: DimsType, eps: float
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     """Computes mean and 1/std of a tensor along norm_dims.
 | |
| 
 | |
|     Used as a helper function for normalization layers.
 | |
| 
 | |
|     Args:
 | |
|         a (Tensor): input tensor
 | |
|         norm_dims (DimsType): dimensions to normalize over
 | |
|         eps (float): epsilon for numerical stability
 | |
| 
 | |
|     Returns:
 | |
|         out (Tensor): normalized tensor.
 | |
|         mean (Tensor): mean of the tensor along norm_dims.
 | |
|         rstd (Tensor): 1/std of the tensor along norm_dims.
 | |
|     """
 | |
|     norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
 | |
|     computation_dtype = utils.get_computation_dtype(a.dtype)
 | |
|     a_acc = _maybe_convert_to_dtype(a, computation_dtype)
 | |
|     assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
 | |
|     biased_var, mean = torch.var_mean(
 | |
|         a_acc, dim=norm_dims, unbiased=False, keepdim=True
 | |
|     )
 | |
|     rstd = torch.rsqrt(biased_var + eps)
 | |
|     out = (a_acc - mean) * rstd
 | |
|     return out, mean, rstd
 | |
| 
 | |
| 
 | |
| # add all specified dimensions
 | |
| def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType:
 | |
|     for dim in sorted(dimensions):
 | |
|         x = torch.unsqueeze(x, dim)
 | |
|     return x
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.native_group_norm.default)
def native_group_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    batch_size: int,
    num_channels: int,
    flattened_inner_size: int,
    num_groups: int,
    eps: float,
) -> tuple[Tensor, Tensor, Tensor]:
    """Decomposition of ``aten.native_group_norm``.

    The input is viewed as ``[batch_size, num_groups, num_channels //
    num_groups, flattened_inner_size]`` and normalized over the last two of
    those axes, then the optional per-channel ``weight`` and ``bias`` are
    applied.

    Returns:
        ``(out, mean, rstd)``, all cast to ``input.dtype``; ``mean`` and
        ``rstd`` have the two reduced axes squeezed away.
    """
    torch._check(
        input.ndim >= 2,
        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
    )
    torch._check(
        num_channels % num_groups == 0,
        lambda: "Expected number of channels in input to be divisible by num_groups, "
        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
    )

    channels_per_group = num_channels // num_groups
    acc_dtype = utils.get_computation_dtype(input.dtype)
    input_acc = _maybe_convert_to_dtype(input, acc_dtype)
    grouped = torch.reshape(
        input_acc,
        [batch_size, num_groups, channels_per_group, flattened_inner_size],
    )
    # The channels within a group and the flattened inner axis are reduced.
    reduction_dims = utils.canonicalize_dims(grouped.ndim, [2, 3])
    biased_var, mean = torch.var_mean(
        grouped, dim=reduction_dims, unbiased=False, keepdim=True
    )
    rstd = torch.rsqrt(biased_var + eps)

    if input.device.type == "cpu" and weight is not None:
        # CPU path with a weight: fold the normalization and the affine
        # parameters into one per-(batch, channel) scale and shift.
        per_group_shape = [1, num_groups, channels_per_group, 1]
        w = rstd * torch.reshape(weight, per_group_shape)
        b = -mean * w
        if bias is not None:
            b = b + torch.reshape(bias, per_group_shape)
        flat_size = [batch_size, num_channels]
        flat_stride = [num_channels, 1]
        w = w.contiguous().as_strided(flat_size, flat_stride)
        b = b.contiguous().as_strided(flat_size, flat_stride)
        trailing_dims = list(range(2, input.ndim))
        scale = _unsqueeze_multiple(w, trailing_dims)
        shift = _unsqueeze_multiple(b, trailing_dims)
        out = input_acc * scale + shift
    else:
        out = ((grouped - mean) * rstd).view(input.shape)
        param_dims = [0, *range(2, input.ndim)]
        if weight is not None:
            channel_weight = _unsqueeze_multiple(weight, param_dims)
            out = out * channel_weight
        if bias is not None:
            channel_bias = _unsqueeze_multiple(bias, param_dims)
            out = out + channel_bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
    rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]

    # Drop the kept reduction axes from the statistics.
    mean = torch.squeeze(mean, reduction_dims)
    rstd = torch.squeeze(rstd, reduction_dims)
    return (out, mean, rstd)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.native_layer_norm)
@out_wrapper("out0", "out1", "out2")
def native_layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> tuple[Tensor, Tensor, Tensor]:
    """Decomposition of ``aten.native_layer_norm``.

    Normalizes ``input`` over its trailing ``len(normalized_shape)``
    dimensions and applies the optional ``weight`` and ``bias``.  Shape
    comparisons use ``sym_eq``.

    Returns:
        ``(out, mean, rstd)``.  ``out`` is cast to ``input.dtype``; ``mean``
        and ``rstd`` are cast to it only on ``cpu`` and ``mtia`` devices.
    """
    from torch.fx.experimental.symbolic_shapes import sym_eq

    normalized_ndim = len(normalized_shape)
    torch._check(
        normalized_ndim >= 1,
        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
        + "containing at least one element, but got normalized_shape = "
        + str(normalized_shape),
    )
    # torch.Size([1, 2, 3]) == [1, 2, 3] is False whereas
    # torch.Size([1, 2, 3]) == (1, 2, 3) is True, so compare against a tuple.
    expected_shape = tuple(normalized_shape)
    torch._check(
        weight is None or sym_eq(weight.shape, expected_shape),
        lambda: "Expected weight to be of same shape as normalized_shape, but got "
        + "weight of shape "
        + str(weight.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    torch._check(
        bias is None or sym_eq(bias.shape, expected_shape),
        lambda: "Expected bias to be of same shape as normalized_shape, but got "
        + "bias of shape "
        + str(bias.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    torch._check(
        input.ndim >= normalized_ndim
        and sym_eq(input.shape[(input.ndim - normalized_ndim) :], expected_shape),
        lambda: "Given normalized_shape="
        + str(normalized_shape)
        + ", expected input with shape "
        + str(normalized_shape)
        + ", but got input of size "
        + str(input.shape),
    )

    input = contiguous(input)
    if weight is not None:
        weight = contiguous(weight)
    if bias is not None:
        bias = contiguous(bias)

    first_norm_dim = input.ndim - normalized_ndim
    reduction_dims = list(range(first_norm_dim, input.ndim))
    out, mean, rstd = _normalize(input, reduction_dims, eps)

    # Scale first, then shift; either step is skipped when its parameter is
    # absent.
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    if input.device.type in ("cpu", "mtia"):
        mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
        rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]
    return (out, mean, rstd)
 | |
| 
 | |
| 
 | |
@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default)
def native_layer_norm_fake(fake_mode, func, *args, **kwargs):
    """Fake-tensor implementation of ``aten.native_layer_norm.default``.

    Delegates to the ``native_layer_norm`` reference defined in this module.
    ``fake_mode`` and ``func`` are part of the registered signature but are
    not used here.

    Keyword arguments are forwarded together with the positional ones so that
    an operand supplied by keyword is not silently dropped.
    """
    return native_layer_norm(*args, **kwargs)
 | |
| 
 | |
| 
 | |
# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(aten.permute)
def permute(a: TensorLikeType, *dims) -> TensorLikeType:
    """Reference implementation of ``aten.permute``.

    ``dims`` is taken as varargs and normalized with
    ``utils.extract_dims_from_varargs``; the resulting dimensions are
    canonicalized against ``a.ndim`` and handed to ``prims.transpose``.
    """
    requested_dims = utils.extract_dims_from_varargs(dims)
    order = utils.canonicalize_dims(a.ndim, requested_dims)
    return prims.transpose(a, order)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.renorm)
@out_wrapper()
def renorm(
    input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType
) -> TensorLikeType:
    """Reference implementation of ``aten.renorm``.

    Computes the ``p``-norm over every dimension except ``dim`` and rescales
    the slices whose norm exceeds ``maxnorm`` by ``maxnorm / (norm + eps)``;
    all other slices are multiplied by ``1.0``. The result is returned
    contiguous.

    Raises (via ``torch._check``) when ``p`` or ``maxnorm`` is complex, when
    ``p <= 0``, when ``maxnorm < 0``, or when ``input`` has fewer than two
    dimensions.
    """
    torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued")
    torch._check(p > 0, lambda: "renorm: non-positive norm not supported")
    torch._check(
        not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued"
    )
    torch._check(
        maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}"
    )
    ndim = input.ndim
    torch._check(
        ndim > 1,
        lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions",
    )

    dim = utils.canonicalize_dim(ndim, dim)
    # Reduce over every dimension except the one being renormalized
    other_dims = [d for d in range(ndim) if d != dim]

    # For half and bfloat16, calculate norm in float precision then cast
    # normalization factor to half
    compute_dtype = utils.get_computation_dtype(input.dtype)
    needs_upcast = compute_dtype != input.dtype
    if needs_upcast:
        norm = torch.linalg.vector_norm(
            input, p, other_dims, keepdim=True, dtype=compute_dtype
        )
    else:
        norm = torch.linalg.vector_norm(input, p, other_dims, keepdim=True)

    eps = 1e-7
    scale = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
    if needs_upcast:
        scale = prims.convert_element_type(scale, input.dtype)
    rescaled = input * scale
    return rescaled.contiguous()
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
def stft(
    input: Tensor,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[Tensor] = None,
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
    align_to_window: Optional[bool] = None,
) -> Tensor:
    """Python implementation of ``aten.stft.center``.

    Args:
        input: 1D or 2D tensor of floating point or complex values.
        n_fft: length of each frame handed to the FFT.
        hop_length: step between frames; defaults to ``n_fft // 4``.
        win_length: window length; defaults to ``n_fft``.
        window: optional 1D tensor of shape ``(win_length,)`` on the same
            device as ``input``. It is zero-padded on both sides up to
            ``n_fft`` when ``win_length < n_fft``.
        center: when true, ``input`` is padded by ``n_fft // 2`` on both
            sides with ``pad_mode`` before framing.
        pad_mode: padding mode forwarded to ``aten.pad``.
        normalized: selects ``norm="ortho"`` for the FFT.
        onesided: use ``rfft`` instead of ``fft``; defaults to true unless the
            windowed input is complex.
        return_complex: whether to return the complex result or its
            ``view_as_real``; must be given unless the input or window is
            complex.
        align_to_window: only accepted when ``center`` is false; when truthy
            the input is padded by ``(n_fft - win_length) // 2`` per side.

    Returns:
        The transformed frames with the frequency dimension ahead of the
        frame dimension; the leading batch dimension is dropped again when
        ``input`` was 1D.

    Raises:
        RuntimeError: from ``torch._check`` when any argument check fails.
    """
    torch._check(
        window is None or window.device == input.device,
        lambda: (
            f"stft input and window must be on the same device but got self on {input.device}"
            + f" and window on {window.device}"  # type: ignore[union-attr]
        ),
    )
    # torch._check requires the message to be a callable; a bare string would
    # surface as TypeError("message must be a callable") when the check fails.
    torch._check(
        not center or align_to_window is None,
        lambda: "stft only supports align_to_window for center = False.",
    )

    hop_length_ = hop_length if hop_length is not None else n_fft // 4
    win_length_ = win_length if win_length is not None else n_fft

    if return_complex is None:
        return_complex_ = input.is_complex() or (
            window is not None and utils.is_complex_dtype(window.dtype)
        )
        torch._check(
            return_complex_,
            lambda: (
                "stft requires the return_complex parameter be given for real inputs, "
                + "and will further require that return_complex=True in a future PyTorch release."
            ),
        )
    else:
        return_complex_ = return_complex

    torch._check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: "stft expected a tensor of floating point or complex values",
    )
    torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")

    # Work on a (batch, time) tensor; the batch dim is removed again at the end
    original_ndim = input.ndim
    if original_ndim == 1:
        input = input.unsqueeze(0)

    if center:
        # Pad through a 3D view of the input, then drop the extra leading dims
        extra_dims = 3 - input.ndim
        pad_amount = n_fft // 2
        extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
        input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
        input = input.view(input.size()[extra_dims:])

    length = input.size(1)
    torch._check(
        0 < n_fft <= length,
        lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
    )
    torch._check(
        hop_length_ > 0,
        lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
    )
    torch._check(
        0 < win_length_ <= n_fft,
        lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
    )
    torch._check(
        window is None or window.shape == (win_length_,),
        lambda: (
            f"expected a 1D window tensor of size equal to win_length={win_length_}, "
            + f"but got window with size {window.shape}"  # type: ignore[union-attr]
        ),
    )

    if win_length_ < n_fft:
        # Zero-pad the (possibly implicit all-ones) window up to n_fft
        if window is None:
            window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
        left = (n_fft - win_length_) // 2
        window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])

    if not center and align_to_window:
        input_pad_amount = (n_fft - win_length_) // 2
        input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode)

    # Slice the signal into frames of n_fft samples, hop_length_ apart
    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)

    if window is not None:
        input = input * window

    complex_fft = utils.is_complex_dtype(input.dtype)
    onesided = onesided if onesided is not None else not complex_fft
    norm = "ortho" if normalized else None
    if onesided:
        torch._check(
            not complex_fft,
            lambda: "Cannot have onesided output if window or input is complex",
        )
        out = torch.fft.rfft(input, dim=-1, norm=norm)
    else:
        out = torch.fft.fft(input, dim=-1, norm=norm)

    # Swap the frame and frequency dimensions
    out.transpose_(1, 2)

    if original_ndim == 1:
        out = out.squeeze_(0)

    return out if return_complex_ else torch.view_as_real(out)
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def istft(
    input: Tensor,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[Tensor] = None,
    center: bool = True,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    length: Optional[int] = None,
    return_complex=False,
) -> Tensor:
    """Python implementation of ``aten.istft.default``.

    Args:
        input: complex tensor with 2 or 3 dimensions whose last two
            dimensions are (frequency, frame).
        n_fft: FFT size of the frames.
        hop_length: step between frames; defaults to ``n_fft // 4``.
        win_length: window length; defaults to ``n_fft``.
        window: optional 1D tensor of shape ``(win_length,)`` on the same
            device as ``input``; an all-ones window is used when omitted.
        center: when true, ``n_fft // 2`` samples are trimmed from the start
            (and from the end when ``length`` is not given).
        normalized: selects ``norm="ortho"`` for the inverse FFT.
        onesided: whether ``input`` holds a one-sided spectrum; defaults to
            ``fft_size != n_fft``.
        length: requested output length; the tail is zero-padded (with a
            warning) when the reconstructed signal is shorter.
        return_complex: use ``ifft`` (complex output) instead of ``irfft``.

    Returns:
        The overlap-added signal divided by the overlap-added squared window.

    Raises:
        RuntimeError: from ``torch._check`` when any argument check fails.
    """
    torch._check(
        window is None or window.device == input.device,
        lambda: (
            f"istft input and window must be on the same device but got self on {input.device}"
            + f" and window on {window.device}"  # type: ignore[union-attr]
        ),
    )

    hop_length_ = hop_length if hop_length is not None else n_fft // 4
    win_length_ = win_length if win_length is not None else n_fft

    # The message describes the dtype requirement being checked. It must not
    # touch `window`, which may be None at this point.
    torch._check(
        utils.is_complex_dtype(input.dtype),
        lambda: (
            "istft requires a complex-valued input tensor matching the output "
            + "from stft with return_complex=True."
        ),
    )
    n_frames = input.size(-1)
    fft_size = input.size(-2)

    expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
    torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
    torch._check(
        2 <= input.ndim <= 3,
        lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
    )
    onesided_ = onesided if onesided is not None else fft_size != n_fft

    if onesided_:
        torch._check(
            n_fft // 2 + 1 == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + f"to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
            ),
        )
    else:
        torch._check(
            n_fft == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + f"to match n_fft when onesided=False, but got {fft_size}"
            ),
        )

    torch._check(
        0 < hop_length_ <= win_length_,
        lambda: "istft expected 0 < hop_length <= win_length",
    )
    torch._check(
        0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
    )
    torch._check(
        window is None or window.shape == (win_length_,),
        lambda: "Invalid window shape. window has to be 1D and length of `win_length`",
    )

    if window is None:
        real_dtype = utils.corresponding_real_dtype(input.dtype)
        window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
    else:
        window_ = window

    if win_length_ != n_fft:
        # Zero-pad the window on both sides up to n_fft
        left = (n_fft - win_length_) // 2
        window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)

    # Work on a 3D tensor; the added leading dim is removed again at the end
    original_ndim = input.ndim
    if input.ndim == 2:
        input = input.unsqueeze(0)

    # Move the frequency dimension last so the inverse FFT runs over it
    input = input.transpose(1, 2)
    norm = "ortho" if normalized else None
    if return_complex:
        torch._check(
            not onesided_,
            lambda: "cannot have onesided output if window or input is complex",
        )
        input = torch.fft.ifft(input, dim=-1, norm=norm)
    else:
        torch._check(
            window is None or not utils.is_complex_dtype(window.dtype),
            lambda: "Complex windows are incompatible with return_complex=False",
        )
        if not onesided_:
            # irfft consumes only the first n_fft // 2 + 1 frequency bins
            input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
        input = torch.fft.irfft(input, dim=-1, norm=norm)

    assert input.size(2) == n_fft

    # Overlap-add the windowed frames and the squared window
    y_tmp = input * window_.view([1, 1, n_fft])
    y = aten.unfold_backward(
        y_tmp,
        input_sizes=(y_tmp.size(0), expected_output_signal_len),
        dim=1,
        size=n_fft,
        step=hop_length_,
    )
    window_envelop = aten.unfold_backward(
        window_.pow(2).expand((1, n_frames, n_fft)),
        input_sizes=(y_tmp.size(0), expected_output_signal_len),
        dim=1,
        size=n_fft,
        step=hop_length_,
    )

    assert expected_output_signal_len == y.size(1)
    assert expected_output_signal_len == window_envelop.size(1)

    start = n_fft // 2 if center else 0
    if length is not None:
        end = start + length
    elif center:
        end = expected_output_signal_len - n_fft // 2
    else:
        end = expected_output_signal_len

    length = max(0, end - start)
    y = y.narrow(dim=1, start=start, length=length)
    window_envelop = window_envelop.narrow(dim=1, start=start, length=length)

    y = y / window_envelop
    if original_ndim == 2:
        y = y.squeeze(0)

    if end > expected_output_signal_len:
        warnings.warn(
            "The length of signal is shorter than the length parameter. Result is being "
            + "padded with zeros in the tail. Please check your center and hop_length settings"
        )
        y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
    return y
 | |
| 
 | |
| 
 | |
| # Get the new shape and stride after applying unfold to an input tensor
 | |
| def _get_unfold_shape_stride(
 | |
|     a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
 | |
| ):
 | |
|     a_ndim = len(a_shape)
 | |
|     dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
 | |
|     max_size = 1 if a_ndim == 0 else a_shape[dim]
 | |
|     last_stride = 1 if a_ndim == 0 else a_stride[dim]
 | |
| 
 | |
|     torch._check(
 | |
|         size <= max_size,
 | |
|         lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
 | |
|     )
 | |
| 
 | |
|     torch._check(
 | |
|         step > 0,
 | |
|         lambda: f"Step is {step} but must be > 0",
 | |
|     )
 | |
| 
 | |
|     shape = list(a_shape)
 | |
|     strides = list(a_stride)
 | |
|     shape.append(size)
 | |
|     strides.append(last_stride)
 | |
|     if dim < a_ndim:
 | |
|         shape[dim] = (shape[dim] - size) // step + 1
 | |
|         strides[dim] *= step
 | |
|     return shape, strides
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.repeat)
@out_wrapper()
def repeat(a: Tensor, *repeat_shape) -> Tensor:
    """Reference implementation of ``aten.repeat``.

    ``repeat_shape`` must have at least as many entries as ``a`` has
    dimensions. The shape of ``a`` is left-padded with ones to that length,
    and every dimension of the result is the padded size times the matching
    repeat count.
    """
    repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
    torch._check(
        len(repeat_shape) >= len(a.shape),
        lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )

    if len(repeat_shape) == 0:
        return torch.clone(a)

    # Left-pad the input shape with ones to match the number of repeat dims
    leading_ones = len(repeat_shape) - a.ndim
    padded_shape = [1] * leading_ones + list(a.shape)

    target_shape = tuple(
        padded * count for padded, count in zip(padded_shape, repeat_shape)
    )

    # return an empty tensor if one of the repeat_shape dimensions is zero
    if 0 in repeat_shape:
        return torch.empty(
            target_shape,
            dtype=a.dtype,
            device=a.device,
            requires_grad=a.requires_grad,
            memory_format=utils.suggest_memory_format(a),
        )

    # Unfold every dimension of the contiguous target layout in turn
    unfolded_shape = target_shape
    unfolded_stride = utils.make_contiguous_strides_for(target_shape)
    for dim, dim_size in enumerate(padded_shape):
        unfolded_shape, unfolded_stride = _get_unfold_shape_stride(
            unfolded_shape, unfolded_stride, dim, dim_size, max(dim_size, 1)
        )

    # derive permute order by sorting the unfolded strides, largest first
    by_stride_desc = sorted(
        enumerate(unfolded_stride), key=operator.itemgetter(1), reverse=True
    )
    permute_order = tuple(position for position, _stride in by_stride_desc)

    # expand to the unfolded shape, then clone to concretize expanded dimensions
    materialized = torch.clone(a.expand(unfolded_shape))

    # transpose axes so strides are in sorted order, then reshape to get a
    # contiguous tensor with the correct target shape
    return materialized.permute(permute_order).reshape(target_shape)
 | |
| 
 | |
| 
 | |
def _reshape_view_helper_core_alg(
    a: TensorLikeType, shape, allow_copy: bool
) -> TensorLikeType:
    """Reshape ``a`` to ``shape`` through a sequence of view operations.

    Returns ``prims.view_of(a)`` when no operation was needed. When a required
    flatten cannot be expressed as a view, falls back to
    ``prims.reshape(a, shape)`` if ``allow_copy`` is true and raises
    ``ValueError`` otherwise.
    """
    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            # NOTE: using split_dim instead of unsqueeze may seem silly here,
            # but it's necessary to get the strides correct
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if length == a_.shape[idx]:
            idx = idx + 1
            continue

        # Accumulates following dimensions until their product is divisible
        # by the desired length
        accum = a_.shape[idx]
        end = idx
        while accum % length != 0:
            end += 1
            accum *= a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view or a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, _new_strides = prims._collapse_view_helper(a_, idx, end)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length.
        # Nothing to split when the accumulated length already equals the target.
        if accum != length:
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail: every remaining dimension must have length 1
    while idx < a_.ndim:
        torch._check(
            a_.shape[idx] == 1,
            lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
        )
        a_ = squeeze(a_, idx)

    # No operation was applied, so hand back a fresh view of the input
    if a_ is a:
        return prims.view_of(a)
    else:
        return a_
 | |
| 
 | |
| 
 | |
def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
    """Shared implementation behind the reshape/view style references.

    Normalizes ``shape`` (which may be passed unpacked or as one sequence and
    may contain a ``-1`` entry to be inferred from ``a.numel()``), handles
    several special cases directly, and otherwise defers to
    ``_reshape_view_helper_core_alg``.

    Args:
        a: tensor to reshape.
        *shape: requested shape.
        allow_copy: forwarded to the core algorithm, where it decides between
            falling back to a copying reshape and raising.
    """
    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    # Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    shape = utils.infer_size(shape, a.numel())

    # Special-cases tensors with no elements
    if a.numel() == 0:
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        if _a is a:
            return prims.view_of(a)
        else:
            return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        if _a is a:
            return prims.view_of(a)
        else:
            return _a

    if is_contiguous_or_false(a):
        # Special-cases for nd_to_1d
        if len(shape) == 1 and a.ndim > 1:
            return torch.as_strided(a, [a.numel()], [1])
        # Special-cases for 1d_to_2d
        if len(shape) == 2 and a.ndim == 1:
            dim0 = shape[0]
            dim1 = shape[1]
            return torch.as_strided(a, [dim0, dim1], [dim1, 1])

    shape_numel = reduce(operator.mul, shape, 1)
    # torch._check requires a callable message; a plain string would surface
    # as TypeError("message must be a callable") when the check fails.
    torch._check(
        a.numel() == shape_numel,
        lambda: f"Could not reshape a tensor with shape {a.shape} as a tensor with shape {shape}!",
    )

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape
    return _reshape_view_helper_core_alg(a, shape, allow_copy)
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
# NOTE: shape is a vararg because Tensor.reshape can be called either as
# Tensor.reshape(a, b, c) or as Tensor.reshape((a, b, c)). The function call
# torch.reshape doesn't support unpacked shapes.
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    """Reshape ``a`` to ``shape`` via ``_reshape_view_helper``, with copying allowed."""
    reshaped = _reshape_view_helper(a, *shape, allow_copy=True)
    return reshaped
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    """Return ``self`` reshaped to the size of ``other``."""
    target_size = other.size()
    return self.reshape(target_size)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.roll)
@out_wrapper()
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`."""
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Avoid modulo by zero
    if a.numel() == 0:
        # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
        return a.clone()

    if a.dim() == 0 and len(dims) > 0:
        raise IndexError(
            f"Dimension specified as {dims[0]} but tensor has no dimensions"
        )

    len_shifts = len(shifts)
    len_dims = len(dims)

    # Base case: exactly one dimension is rolled. The multi-dimension case
    # below reduces to this one dimension at a time.
    if len_shifts == 1 and len_dims == 1:
        dim = dims[0]
        size = a.shape[dim]
        start = (size - shifts[0]) % size
        positions = torch.arange(size, device=a.device)
        return a.index_select(dim, torch.fmod(start + positions, size))

    if len_shifts == 0:
        raise RuntimeError("`shifts` required")
    # Takes care of the case when dims is not specified (default)
    # By default, the tensor is flattened before shifting, after which the original shape is restored
    if len_dims == 0 and len_shifts == 1:
        return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
    if len_shifts != len_dims:
        raise RuntimeError(
            f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
        )
    assert len_dims > 1
    # Roll the first dimension, then the remaining dimensions
    first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
    return torch.roll(first_dim_rolled, shifts[1:], dims[1:])
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.rot90)
@out_wrapper()
def rot90(
    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
) -> TensorLikeType:
    """Reference implementation of :func:`torch.rot90`."""
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if a.ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")

    # Do this after the initial checks to be compatible with the behavior in
    # core.
    dims = utils.canonicalize_dims(a.ndim, dims)
    dim0, dim1 = dims[0], dims[1]

    if dim0 == dim1:
        raise RuntimeError(
            f"expected rotation dims to be different, but got dim0 = {dim0} and dim1 = {dim1}"
        )

    # Rotation direction is from the second towards the first axis for k < 0
    turns = k % 4
    if turns == 2:
        return torch.flip(a, dims)
    if turns == 1 or turns == 3:
        # One turn flips the second rotation dim, three turns flip the first;
        # both then swap the two rotation dims
        flipped_dim = dim1 if turns == 1 else dim0
        return torch.transpose(torch.flip(a, (flipped_dim,)), dim0, dim1)
    return a.clone(memory_format=torch.contiguous_format)
 | |
| 
 | |
| 
 | |
| def _check_stack_inputs(tensors: TensorSequenceType) -> None:
 | |
|     entry_shape = tensors[0].shape
 | |
|     for i in range(1, len(tensors)):
 | |
|         assert tensors[i].shape == entry_shape, (
 | |
|             f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
 | |
|             f"and {tensors[i].shape} at entry {i}"
 | |
|         )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.stack)
@out_wrapper()
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    """Reference implementation of ``aten.stack``.

    Concatenates the tensors and views the result with an extra dimension of
    length ``len(tensors)`` at the wrapped ``dim``. When the new dimension
    would be the last one, each tensor is unsqueezed and the results are
    concatenated instead.
    """
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    first = tensors[0]
    wrapped_dim = utils.canonicalize_dim(first.ndim + 1, dim)

    # If dim == tensors[0].ndim, view cannot efficiently handle it
    if wrapped_dim >= first.ndim:
        return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)

    # Refs need sparse support to check other condition
    # (and not tensors[0].is_sparse)
    _check_stack_inputs(tensors)
    stacked_shape = list(first.shape)
    stacked_shape.insert(wrapped_dim, len(tensors))
    concatenated = torch.cat(tensors, wrapped_dim)
    return concatenated.view(stacked_shape)
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Softmax of ``a`` along ``dim``.

    The computation runs in the computation dtype of the result dtype
    (``dtype`` if given, otherwise ``a.dtype``) and is converted to the
    result dtype at the end. For non-empty inputs the maximum along ``dim``
    is subtracted before exponentiating.
    """
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    # amax is skipped for inputs without elements
    shifted = a_ if a.numel() == 0 else a_ - amax(a_, dim, keepdim=True)
    exps = exp(shifted)
    normalized = true_divide(exps, sum(exps, dim, keepdim=True))
    return _maybe_convert_to_dtype(normalized, result_dtype)  # type: ignore[return-value]
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Concatenate ``tensors`` after promoting each to at least 1D.

    Concatenation runs along dimension 0 when the first promoted tensor is
    1D and along dimension 1 otherwise.
    """
    torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
    aligned = atleast_1d(*tensors)
    cat_dim = 0 if aligned[0].ndim == 1 else 1
    return cat(aligned, cat_dim)
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
    """Concatenate ``tensors`` along dimension 0 after promoting each to at least 2D."""
    torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
    return cat(atleast_2d(*tensors), 0)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
 | |
|     dim = utils.canonicalize_dim(a.ndim, dim)
 | |
|     torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
 | |
|     return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    """Reference implementation of ``aten.unbind``.

    Returns a tuple with one tensor per index along ``dim``, each with that
    dimension squeezed away; the tuple is empty when the dimension has
    length 0.
    """
    dim = utils.canonicalize_dim(t.ndim, dim)
    torch._check_index(
        len(t.shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
    )

    # Note: t.shape[dim] can't be dynamic or unbacked, even if we use guard_or_false here we will fail
    # later in the split since t.shape[dim] control the number of output tensors.
    num_slices = t.shape[dim]
    if num_slices == 0:
        return ()

    pieces = torch.tensor_split(t, num_slices, dim)
    return tuple(torch.squeeze(piece, dim) for piece in pieces)
 | |
| 
 | |
| 
 | |
@out_wrapper()
def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    """Out-of-place ``index_copy``: apply ``index_copy_`` to a contiguous clone of ``x``."""
    result = x.clone(memory_format=torch.contiguous_format)
    return result.index_copy_(dim, index, tensor)
 | |
| 
 | |
| 
 | |
| def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
 | |
|     dim = utils.canonicalize_dims(x.ndim, dim)
 | |
|     torch._check(
 | |
|         index.ndim <= 1,
 | |
|         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
 | |
|     )
 | |
|     # Treat scalars as elements of \R^1
 | |
|     y = x.unsqueeze(0) if x.ndim == 0 else x
 | |
|     idx = (slice(None),) * dim + (index,)
 | |
|     y[idx] = tensor
 | |
|     return x
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.index_fill)
@out_wrapper()
def index_fill(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    value: Union[NumberType, TensorLike],
):
    """Out-of-place index fill: delegates to ``_index_fill`` with ``inplace=False``.

    Args:
        x: source tensor.
        dim: dimension along which to index.
        index: positions along ``dim`` to fill.
        value: fill value, either a number or a tensor.
    """
    return _index_fill(x, dim, index, value, inplace=False)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.index_fill_)
def index_fill_(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    value: Union[NumberType, TensorLike],
):
    """In-place index fill: delegates to ``_index_fill`` with ``inplace=True``.

    Args:
        x: tensor to modify.
        dim: dimension along which to index.
        index: positions along ``dim`` to fill.
        value: fill value, either a number or a tensor.
    """
    return _index_fill(x, dim, index, value, inplace=True)
 | |
| 
 | |
| 
 | |
| def _index_fill(
 | |
|     x: TensorLike,
 | |
|     dim: int,
 | |
|     index: TensorLike,
 | |
|     value: Union[NumberType, TensorLike],
 | |
|     *,
 | |
|     inplace: bool,
 | |
| ):
 | |
|     torch._check(
 | |
|         index.ndim <= 1,
 | |
|         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
 | |
|     )
 | |
|     if isinstance(value, TensorLike):
 | |
|         torch._check(
 | |
|             value.ndim == 0,
 | |
|             lambda: "Only supports 0-dimensional value tensor. "  # type: ignore[union-attr]
 | |
|             f"Got a tensor with {value.ndim} dimensions.",
 | |
|         )  # type: ignore[arg-type]
 | |
|     else:
 | |
|         value = torch.scalar_tensor(
 | |
|             value,
 | |
|             dtype=x.dtype,
 | |
|             layout=x.layout,
 | |
|             device=x.device,  # type: ignore[arg-type]
 | |
|         )
 | |
| 
 | |
|     # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them
 | |
|     zero_dim = x.ndim == 0
 | |
|     y = x.unsqueeze(0) if zero_dim else x
 | |
|     # index_copy does not broadcast on value so we have to do it manually
 | |
|     shape = list(y.shape)
 | |
|     shape[dim] = index.numel()
 | |
|     value = value.expand(shape)
 | |
|     index_copy = Tensor.index_copy_ if inplace else torch.index_copy
 | |
|     out = index_copy(y, dim, index, value)  # type: ignore[operator]
 | |
|     if inplace:
 | |
|         return x
 | |
|     else:
 | |
|         if zero_dim:
 | |
|             # The clone is necessary so that it returns a fresh tensor rather than a view
 | |
|             out = out.squeeze(0).clone()
 | |
|         # index_fill preserves the strides. index_copy always returns contiguous tensors
 | |
|         if out.stride() != x.stride():
 | |
|             new_out = torch.empty_like(x)
 | |
|             new_out.copy_(out)
 | |
|             out = new_out
 | |
|         return out
 | |
| 
 | |
| 
 | |
@out_wrapper()
def index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    """Out-of-place ``index_add_``.

    Args:
        x: source tensor (cloned, not modified by this function).
        dim: dimension along which to index.
        index: positions along ``dim`` to accumulate into.
        tensor: values to add.
        alpha: multiplier forwarded to ``index_add_``.

    Returns:
        A contiguous clone of ``x`` after ``index_add_`` has been applied to it.
    """
    # index_add always returns a new contiguous tensor
    result = x.clone(memory_format=torch.contiguous_format)
    return result.index_add_(
        dim,
        index,
        tensor,
        alpha=alpha,  # type: ignore[arg-type]
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.index_select)
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
    """Select the entries of ``x`` at positions ``index`` along ``dim``.

    Args:
        x: tensor to select from.
        dim: dimension along which to index; canonicalized against ``x.ndim``.
        index: 0-D or 1-D tensor of positions (a 0-D index is unsqueezed to 1-D).
    """
    dim = utils.canonicalize_dims(x.ndim, dim)
    torch._check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if index.ndim == 0:
        index = index.unsqueeze(0)

    if x.ndim == 0:
        # Treat scalars as elements of R^1.
        # We cannot use x[idx] here as it accesses item() (??), hence this
        # awkward construction via index_copy.
        expanded_source = x.expand_as(index)
        return torch.empty_like(x).index_copy(0, index, expanded_source)

    # Full slices for every dimension before ``dim``, then the index tensor.
    selector = (slice(None),) * dim + (index,)
    return x[selector]
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.squeeze.dims)
def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    """Remove size-1 dimensions from ``a``.

    Args:
        a: tensor to squeeze.
        dim: ``None`` to drop every dimension whose size equals 1, or a single
            dimension / sequence of dimensions to consider. Requested
            dimensions whose size is not 1 are left untouched.

    Returns:
        A squeezed view, or ``prims.view_of(a)`` when nothing is removed.
    """
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if dim is None:
        unit_dims = tuple(i for i, extent in enumerate(a.shape) if extent == 1)
        if unit_dims:
            return prims.squeeze(a, unit_dims)
        return prims.view_of(a)

    ndim = a.ndim
    dim = utils.canonicalize_dims(ndim, dim)
    dims = (dim,) if isinstance(dim, Dim) else dim
    # Short-circuits if the tensor has no dimensions
    if ndim == 0:
        assert len(dims) == 0 or dims == (0,)
        return prims.view_of(a)

    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
    squeezable = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
    if len(squeezable) == 0:
        return prims.view_of(a)
    if len(squeezable) == 1:
        return prims.squeeze(a, squeezable)

    # Squeeze from the highest dimension down so that the remaining indices
    # stay valid after each removal.
    for d in sorted(squeezable, reverse=True):
        a = squeeze(a, d)
    return a
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: list[int], dim: int = 0
) -> list[Tensor]:
    """Split ``self`` along ``dim`` into consecutive views of the given sizes.

    Args:
        self: tensor to split.
        split_sizes: size of each piece along ``dim``; entries must pass
            ``torch._check_is_size`` and sum to ``self.shape[dim]``.
        dim: dimension along which to split.

    Returns:
        A list of ``as_strided`` views of ``self``, one per entry of
        ``split_sizes``.

    Raises:
        ValueError: (through ``torch._check_with``) when the sizes do not add
            up to ``self.shape[dim]``.
    """
    # NB: Perform the check_is_size tests first so that the
    # sum test does not try to do a replacement
    for size in split_sizes:
        torch._check_is_size(
            size,
            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
        )
    torch._check_with(
        ValueError,
        builtins.sum(split_sizes) == self.shape[dim],
        lambda: f"Split sizes add up to {builtins.sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
    )

    strides = self.stride()
    pieces = []
    cursor = self.storage_offset()
    for size in split_sizes:
        piece_shape = list(self.shape)
        piece_shape[dim] = size
        # We reimplement narrow here to avoid a lot of checks in the
        # decomposition of narrow which calls slice_in_dim and slice
        pieces.append(self.as_strided(piece_shape, strides, cursor))
        # Advance the storage offset past the piece that was just emitted.
        cursor = cursor + strides[dim] * size
    return pieces
 | |
| 
 | |
| 
 | |
| # Note: does not work with TensorMetas because of data-dependent control-flow
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> tuple[TensorLikeType, ...]:
    """Split ``a`` along ``dim`` into sections or at the given indices.

    Args:
        a: tensor to split; must have at least one dimension.
        indices_or_sections: either
            * an integer (or 0-D CPU long tensor) ``n``: ``a`` is split into
              ``n`` parts whose sizes differ by at most one, the larger parts
              coming first; or
            * a sequence of integers (or 1-D CPU long tensor) giving the
              split points along ``dim``.
        dim: dimension along which to split; canonicalized against ``a.ndim``.

    Returns:
        A tuple of the pieces produced by ``aten.split_with_sizes``.

    Raises:
        ValueError: for a rank-zero ``a``, a tensor ``indices_or_sections``
            that is not a CPU long tensor or has more than one dimension, or
            a non-positive number of sections.
    """
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if indices_or_sections.device.type != "cpu":
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
                f"but received one on {indices_or_sections.device}"
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                f" but received one with dtype {indices_or_sections.dtype}"
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, IntLike) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
            raise ValueError(msg)

        dim_size = a.shape[_dim]
        # Use integer floor division: going through float true division
        # (math.floor(dim_size / sections)) loses precision for sizes beyond
        # 2**53. dim_size >= 0 and sections > 0 here, so ``//`` is the exact
        # floor of the quotient.
        min_split_size = dim_size // sections
        num_splits_one_extra = dim_size % sections

        # The first ``num_splits_one_extra`` pieces absorb the remainder.
        split_sizes = [
            min_split_size + 1 if split_idx < num_splits_one_extra else min_split_size
            for split_idx in range(sections)
        ]

        return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim))
    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    f"but received a tensor with {indices_or_sections.ndim} dimensions"
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        # Bracket the split points with the start and end of the dimension and
        # turn consecutive boundaries into piece sizes.
        indices = [0] + list(indices) + [a.shape[_dim]]
        split_sizes = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)]
        return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim))
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> tuple[TensorLikeType, ...]:
    """Split ``a`` with ``tensor_split`` along dim 0 (1-D input) or dim 1 (otherwise).

    Args:
        a: tensor with at least one dimension.
        indices_or_sections: an integer number of sections, which must be
            non-zero and divide the size of the split dimension, or a list /
            tuple of split indices.

    Returns:
        The tuple of pieces returned by ``tensor_split``.
    """
    torch._check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, "
            f"but got a tensor with {a.ndim} dimensions!"
        ),
    )
    dim = 0 if a.ndim == 1 else 1

    if not isinstance(indices_or_sections, IntLike):
        # Anything that is not an integer must be an explicit list of indices.
        torch._check_type(
            isinstance(indices_or_sections, (list, tuple)),
            lambda: (
                "hsplit(): received an invalid combination of arguments. "
                "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
                f"but got type {type(indices_or_sections)}"
            ),
        )
        return tensor_split(a, indices_or_sections, dim)

    split_size = indices_or_sections
    # The zero test comes first so the modulo is never evaluated with a zero divisor.
    torch._check(
        (split_size != 0 and a.shape[dim] % split_size == 0),
        lambda: (
            f"torch.hsplit attempted to split along dimension {dim}"
            f", but the size of the dimension {a.shape[dim]}"
            f" is not divisible by the split_size {split_size}!"
        ),
    )
    return tensor_split(a, split_size, dim)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def vsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> tuple[TensorLikeType, ...]:
    """Split ``a`` with ``tensor_split`` along dimension 0.

    Args:
        a: tensor with at least two dimensions.
        indices_or_sections: an integer number of sections, which must be
            non-zero and divide ``a.shape[0]``, or a list / tuple of split
            indices.

    Returns:
        The tuple of pieces returned by ``tensor_split``.
    """
    torch._check(
        a.ndim >= 2,
        lambda: (
            "torch.vsplit requires a tensor with at least 2 dimension, "
            f"but got a tensor with {a.ndim} dimensions!"
        ),
    )

    if not isinstance(indices_or_sections, IntLike):
        # Anything that is not an integer must be an explicit list of indices.
        torch._check_type(
            isinstance(indices_or_sections, (list, tuple)),
            lambda: (
                "vsplit(): received an invalid combination of arguments. "
                "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
                f"but got type {type(indices_or_sections)}"
            ),
        )
        return tensor_split(a, indices_or_sections, 0)

    split_size = indices_or_sections
    # The zero test comes first so the modulo is never evaluated with a zero divisor.
    torch._check(
        (split_size != 0 and a.shape[0] % split_size == 0),
        lambda: (
            "torch.vsplit attempted to split along dimension 0, "
            f"but the size of the dimension {a.shape[0]} "
            f"is not divisible by the split_size {split_size}!"
        ),
    )
    return tensor_split(a, split_size, 0)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.diag.out)
@out_wrapper()
def diag(
    self: TensorLikeType,
    offset: int = 0,
) -> TensorLikeType:
    """Dispatch on rank: 1-D input goes to ``diag_embed``, 2-D to ``diagonal_copy``.

    Args:
        self: a 1-D or 2-D tensor.
        offset: diagonal offset forwarded to the chosen operation.
    """
    rank = self.dim()
    torch._check(
        rank in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {rank}D"
    )
    if rank == 1:
        return torch.diag_embed(self, offset)
    return torch.diagonal_copy(self, offset)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """Return a stride-preserving clone of ``input`` with ``src`` written on a diagonal.

    Args:
        input: tensor that provides the base values (cloned).
        src: values for the diagonal; its shape must equal the shape of
            ``clone.diagonal(offset, dim1, dim2)``.
        offset: which diagonal to write.
        dim1: first dimension of the diagonal.
        dim2: second dimension of the diagonal.
    """
    out = utils.clone_preserve_strides(input)
    # View onto the clone's diagonal; copying into it updates ``out``.
    target_diagonal = out.diagonal(offset, dim1, dim2)
    torch._check(
        target_diagonal.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input."
        f"Got {src.shape} for a diagonal of shape {target_diagonal.shape}",
    )
    copy_to(target_diagonal, src)
    return out
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.diagonal)
def diagonal(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diagonal

    The result is an ``as_strided`` view of ``self``: ``dim1`` and ``dim2`` are
    removed and a new last dimension walks the selected diagonal by stepping
    both original strides at once.
    """
    num_dims = self.dim()
    dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)

    torch._check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    shape = self.size()
    stride = self.stride()

    # Length of the selected diagonal, clamped at zero when the offset falls
    # outside the matrix.
    if offset >= 0:
        diag_len = max(min(shape[dim1], shape[dim2] - offset), 0)
    else:
        diag_len = max(min(shape[dim1] + offset, shape[dim2]), 0)

    # A non-negative offset shifts the start along dim2, a negative one along
    # dim1. An empty diagonal keeps the original storage offset.
    start = self.storage_offset()
    if diag_len > 0:
        if offset >= 0:
            start += offset * stride[dim2]
        else:
            start -= offset * stride[dim1]

    kept = [i for i in range(num_dims) if i not in (dim1, dim2)]
    out_sizes = [shape[i] for i in kept]
    out_sizes.append(diag_len)
    out_strides = [stride[i] for i in kept]
    out_strides.append(stride[dim1] + stride[dim2])

    return self.as_strided(size=out_sizes, stride=out_strides, storage_offset=start)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.diag_embed)
@out_wrapper()
def diag_embed(
    t: TensorLikeType,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diag_embed

    The last dimension of ``t`` is placed on the ``offset`` diagonal of the
    planes spanned by ``dim1`` and ``dim2`` of the output; every other entry
    is masked out through ``utils.mask_tensor``.
    """
    # convert from negative dims (the output has one more dimension than t)
    out_rank = t.ndim + 1
    dim1 = utils.canonicalize_dim(rank=out_rank, idx=dim1)
    dim2 = utils.canonicalize_dim(rank=out_rank, idx=dim2)

    # as per the docs, exchanging dims is equivalent to changing the sign of
    # offset
    if dim1 > dim2:
        dim1, dim2 = dim2, dim1
        offset = -offset

    torch._check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    # as per the docs, the size of last dim is placed at dim1 and dim2
    side = t.size(-1)

    if offset != 0:
        # add zero padding along the last dim to match the new size: in front
        # for a positive offset, behind for a negative one
        shift = builtins.abs(offset)
        pad_shape = list(t.shape)
        pad_shape[-1] = shift
        padding = torch.zeros(
            pad_shape, dtype=t.dtype, device=t.device, requires_grad=False
        )
        pieces = (padding, t) if offset > 0 else (t, padding)
        t = torch.cat(pieces, dim=-1)
        # make sure the diagonal always has the same size
        side += shift

    # preserve original data, but place 1 at dim1 and move last dim to dim2
    t = t.unsqueeze(dim1).movedim(-1, dim2)

    # generate ranges shifting indices based on offset
    col_index = torch.arange(side, device=t.device, dtype=torch.int64)
    row_index = torch.arange(offset, side + offset, device=t.device, dtype=torch.int64)

    # broadcast the comparison into a (side, side) mask, then reshape it so it
    # lines up with dim1/dim2 of ``t`` and broadcasts over every other dim
    on_diagonal = col_index == row_index.unsqueeze(-1)
    mask_shape = [side if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
    on_diagonal = on_diagonal.reshape(mask_shape)

    # aten.diag_embed always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(on_diagonal, t).contiguous()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.block_diag)
@out_wrapper()
def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType:
    """
    Reference implementation of torch.block_diag

    Each input (viewed as a single row when it has at most one dimension) is
    padded with zeros on the left and right to the total column count, and
    the padded rows of blocks are concatenated along dimension 0.
    """
    blocks = [
        block.view(1, -1) if block.dim() <= 1 else block for block in tensors
    ]

    total_cols = builtins.sum(block.shape[1] for block in blocks)
    device = blocks[0].device

    padded_rows = []
    cols_done = 0
    for i, block in enumerate(blocks):
        torch._check(
            block.dim() == 2,
            lambda: "Input tensors must have 2 or fewer dimensions. "
            f"Input {i} has {block.dim()} dimensions",
        )
        torch._check(
            block.device == device,
            lambda: "Input tensors must all be on the same device. "
            f"Input 0 is on device {device} and input {i} is on device {block.device}.",
        )
        n_rows, n_cols = block.shape
        # Zero padding for the columns owned by earlier and later blocks.
        left_pad = torch.zeros((n_rows, cols_done), device=device, dtype=block.dtype)
        right_pad = torch.zeros(
            (n_rows, total_cols - cols_done - n_cols),
            device=device,
            dtype=block.dtype,
        )
        padded_rows.append(torch.cat((left_pad, block, right_pad), dim=1))
        cols_done += n_cols

    return torch.cat(padded_rows, dim=0)
 | |
| 
 | |
| 
 | |
def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType:
    """
    Splatted-argument entry point, used as an input to PythonRefInfo.

    `torch.block_diag` takes its tensors as separate positional arguments,
    whereas `aten.block_diag` takes a single list of Tensors. This wrapper
    collects the positional arguments and forwards them as one sequence to
    `_block_diag_iterable`.
    """
    collected = tensors
    return _block_diag_iterable(collected)  # type: ignore[arg-type]
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
    """Split ``a`` with ``tensor_split`` along dimension 2.

    Args:
        a: tensor with at least three dimensions.
        sections: an integer number of sections, which must be non-zero and
            divide ``a.shape[2]``, or any other value accepted by
            ``tensor_split``.

    Raises:
        RuntimeError: when ``a`` has fewer than three dimensions, or an
            integer ``sections`` is zero or does not divide ``a.shape[2]``.
    """
    if a.ndim < 3:
        raise RuntimeError(
            f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
        )
    if isinstance(sections, IntLike):
        # The zero test comes first so the modulo is never evaluated with a zero divisor.
        if sections == 0 or a.shape[2] % sections != 0:
            raise RuntimeError(
                "torch.dsplit attempted to split along dimension 2, "
                f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
            )
    return tensor_split(a, sections, 2)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.t.default)
def t(a: TensorLikeType):
    """Transpose a tensor that has at most two dimensions.

    For fewer than two dimensions this calls ``torch.transpose(a, 0, 0)``;
    for exactly two it swaps dimensions 0 and 1.

    Raises:
        RuntimeError: when ``a`` has more than two dimensions.
    """
    # TODO: Add sparse support
    # if a.is_sparse:
    #     sparse_dim = a.sparse_dim()
    #     dense_dim = a.dense_dim()
    #     if not (sparse_dim <= 2 and dense_dim == 0):
    #         raise RuntimeError(
    #             f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
    #             f"{dense_dim} dense dimensions"
    #         )
    rank = a.ndim
    if rank > 2:
        raise RuntimeError(
            f"t() expects a tensor with <= 2 dimensions, but self is {rank}D"
        )
    other_dim = 1 if rank >= 2 else 0
    return torch.transpose(a, 0, other_dim)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| def T(a: TensorLikeType) -> TensorLikeType:
 | |
|     # n != 2 && n != 0 is deprecated in regular PyTorch.
 | |
|     torch._check(
 | |
|         a.ndim in (0, 2),
 | |
|         lambda: (
 | |
|             "The use of `x.T` on tensors of dimension other than 0 or 2 "
 | |
|             "to reverse their shape is not supported."
 | |
|         ),
 | |
|     )
 | |
|     return a.t()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.alias)
def alias(a: TensorLikeType) -> TensorLikeType:
    """Return ``prims.view_of(a)``."""
    aliased = prims.view_of(a)
    return aliased
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.transpose)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    """Swap dimensions ``dim0`` and ``dim1`` of ``a``.

    Args:
        a: tensor to transpose.
        dim0: first dimension; canonicalized against ``a.ndim``.
        dim1: second dimension; canonicalized against ``a.ndim``.

    Returns:
        An alias of ``a`` when there is nothing to swap (at most one dimension,
        or both arguments name the same dimension), otherwise the result of
        ``torch.permute`` with the two dimensions exchanged.
    """
    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    # Compare the canonicalized dims, not the raw arguments: e.g. (0, -2) on a
    # 2-D tensor names the same dimension twice and must take the alias path
    # instead of building an identity permutation.
    if a.ndim <= 1 or _dim0 == _dim1:
        return aten.alias.default(a)

    _permutation = list(range(a.ndim))
    _permutation[_dim0] = _dim1
    _permutation[_dim1] = _dim0
    return torch.permute(a, _permutation)


# Aliases for transpose
swap_axes = transpose
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.unfold)
def unfold(
    self: TensorLikeType, dimension: int, size: int, step: int
) -> TensorLikeType:
    """Return an ``as_strided`` view of ``self`` for the unfold operation.

    The output shape and strides are computed by ``_get_unfold_shape_stride``
    from the input's shape and strides together with ``dimension``, ``size``
    and ``step``.
    """
    out_shape, out_strides = _get_unfold_shape_stride(
        self.shape, self.stride(), dimension, size, step
    )
    unfolded = self.as_strided(out_shape, out_strides)
    return unfolded
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.unfold_copy)
@out_wrapper()
def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
    """Return a contiguous clone of ``self.unfold(dimension, size, step)``."""
    unfolded_view = self.unfold(dimension, size, step)
    return unfolded_view.clone(memory_format=torch.contiguous_format)
 | |
| 
 | |
| 
 | |
def _cumsumprod_common(
    func,
    init,
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Express a cumulative reduction along ``dim`` as a masked full reduction.

    A copy of the ``dim`` axis is broadcast into a new axis at ``dim + 1``.
    Entries where the source index exceeds the output index are replaced by
    ``init``, and ``func`` then reduces over ``dim``.

    Args:
        func: reduction called as ``func(tensor, dim=..., dtype=..., out=...)``.
        init: value substituted for masked-out entries.
        a: input tensor.
        dim: dimension to accumulate over; canonicalized against ``a.ndim``.
        dtype: forwarded to ``func``.
        out: forwarded to ``func``.
    """
    # We implement all the kwargs of a reduction. ATen just handles dtype
    # nb. This decomposition may not be as efficient as a backend-specific implementation
    rank = a.ndim
    dim = utils.canonicalize_dim(rank, dim)
    if rank == 0:
        return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out)

    a = a.unsqueeze(dim + 1)
    positions = torch.arange(a.shape[dim], device=a.device)
    # keep[i, j] is True when source index i contributes to output index j.
    keep = positions.unsqueeze(1) <= positions
    # Append singleton axes so the mask lines up with the dims after ``dim + 1``.
    trailing_dims = rank - dim - 1
    for _ in range(trailing_dims):
        keep = keep.unsqueeze(-1)
    masked = torch.where(keep, a, init)
    return func(masked, dim=dim, dtype=dtype, out=out)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.cumsum)
def cumsum(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Cumulative sum along ``dim``: ``_cumsumprod_common`` with ``sum`` and init 0."""
    return _cumsumprod_common(
        func=sum,
        init=0,
        a=a,
        dim=dim,
        dtype=dtype,
        out=out,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.cumprod)
def cumprod(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    """Cumulative product along ``dim``: ``_cumsumprod_common`` with ``prod`` and init 1."""
    return _cumsumprod_common(
        func=prod,
        init=1,
        a=a,
        dim=dim,
        dtype=dtype,
        out=out,
    )
 | |
| 
 | |
| 
 | |
| # Note: although squeeze is documented as having the out= kwarg it doesn't
 | |
@register_decomposition(aten.unsqueeze)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    """Insert a new dimension at position ``dim`` via ``prims.expand_dims``.

    Args:
        a: input tensor.
        dim: position of the new dimension, canonicalized against
            ``a.ndim + 1``.
    """
    # Note that unsqueeze canonicalizes with rank + 1 because it allows
    # a new innermost dimension to be specified
    out_rank = a.ndim + 1
    position = utils.canonicalize_dim(out_rank, dim)
    return prims.expand_dims(a, (position,), ndim=out_rank)
 | |
| 
 | |
| 
 | |
| # NOTE: shape is a vararg because Tensor.view can be called as either
 | |
| # Tensor.view(a, b, c) or Tensor.view((a, b, c)). Function call torch.view
 | |
| # doesn't support unpacked shapes.
 | |
| # TODO: Turn this into a decomposition (currently fails on reshape meta tests)
 | |
@register_decomposition(aten.view.default)
def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    """View ``a`` with a new shape; delegates to ``_reshape_view_helper`` with ``allow_copy=False``.

    Args:
        a: tensor to view.
        *shape: target shape, forwarded unchanged to the helper.
    """
    viewed = _reshape_view_helper(a, *shape, allow_copy=False)
    return viewed
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    """Return ``self`` viewed with the size of ``other``."""
    target_size = other.size()
    return self.view(target_size)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
def ravel(a: TensorLikeType) -> TensorLikeType:
    """Return ``a`` reshaped to a single dimension, i.e. ``reshape(a, (-1,))``."""
    flat_shape = (-1,)
    return reshape(a, flat_shape)
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| # missing ref impl. for aten.gather
 | |
@out_wrapper()
def take_along_dim(
    a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
) -> torch.Tensor:
    """Gather values from ``a`` at ``indices`` along ``dim``.

    Args:
        a: input tensor.
        indices: integer-dtype tensor with the same number of dimensions as ``a``.
        dim: dimension to gather along. When ``None``, both tensors are
            flattened with ``view(-1)`` and gathered along dimension 0.

    Returns:
        The result of ``torch.gather`` on the (broadcast) inputs.
    """
    torch._check(
        a.ndim == indices.ndim,
        lambda: (
            "torch.take_along_dim(): input and indices should have the same "
            f"number of dimensions, but got {a.ndim} dimensions for input, and "
            f"{indices.ndim} dimensions for indices"
        ),
    )

    torch._check(
        utils.is_integer_dtype(indices.dtype),
        lambda: (
            "torch.take_along_dim(): dtype of indices should be int but got "
            f"{indices.dtype} instead"
        ),
    )

    if dim is None:
        return torch.gather(a.view(-1), 0, indices.view(-1))

    # Broadcast ``indices`` against the shape of ``a`` with the extent along
    # ``dim`` replaced by that of ``indices``.
    a_shape_for_indices = list(a.shape)
    a_shape_for_indices[dim] = indices.size(dim)
    indices_target = utils.infer_size_shapes(a_shape_for_indices, indices.size())
    indices_broadcast = broadcast_to(indices, indices_target)

    # Symmetrically, broadcast ``a`` against the shape of ``indices`` with the
    # extent along ``dim`` replaced by that of ``a``.
    indices_shape_for_a = list(indices.shape)
    indices_shape_for_a[dim] = a.size(dim)
    self_target = utils.infer_size_shapes(indices_shape_for_a, a.size())
    self_broadcast = broadcast_to(a, self_target)

    return torch.gather(self_broadcast, dim, indices_broadcast)
 | |
| 
 | |
| 
 | |
| @out_wrapper()
 | |
| def empty(
 | |
|     *shape,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     layout: torch.layout = torch.strided,
 | |
|     device: Optional[DeviceLikeType] = None,
 | |
|     requires_grad: bool = False,
 | |
|     pin_memory: bool = False,
 | |
|     memory_format: torch.memory_format = torch.contiguous_format,
 | |
| ) -> TensorLikeType:
 | |
|     torch._check(
 | |
|         memory_format != torch.preserve_format,
 | |
|         lambda: "torch.empty: the Preserve memory format is not supported",
 | |
|     )
 | |
| 
 | |
|     shape = utils.extract_shape_from_varargs(shape)
 | |
| 
 | |
|     if memory_format == torch.contiguous_format:
 | |
|         strides = utils.make_contiguous_strides_for(shape)
 | |
|     elif memory_format == torch.channels_last_3d:
 | |
|         strides = utils.make_channels_last_3d_strides_for(shape)
 | |
|     else:  # memory_format == torch.channels_last
 | |
|         torch._check(
 | |
|             memory_format == torch.channels_last,
 | |
|             lambda: f"torch.empty: received an unknown memory format {memory_format}!",
 | |
|         )
 | |
|         strides = utils.make_channels_last_2d_strides_for(shape)
 | |
| 
 | |
|     return torch.empty_strided(
 | |
|         shape,
 | |
|         strides,
 | |
|         dtype=dtype,
 | |
|         layout=layout,
 | |
|         device=device,
 | |
|         pin_memory=pin_memory,
 | |
|         requires_grad=requires_grad,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @out_wrapper()
 | |
| def empty_permuted(
 | |
|     shape,
 | |
|     physical_layout,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     layout: torch.layout = torch.strided,
 | |
|     device: Optional[DeviceLikeType] = None,
 | |
|     requires_grad: bool = False,
 | |
|     pin_memory: bool = False,
 | |
| ) -> TensorLikeType:
 | |
|     return prims.empty_permuted(
 | |
|         shape,
 | |
|         physical_layout,
 | |
|         dtype=dtype,
 | |
|         device=device,
 | |
|         requires_grad=requires_grad,
 | |
|     )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.new_empty)
@out_wrapper()
def new_empty(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """Create an uninitialized tensor of ``size`` whose options default to those of ``a``.

    ``dtype``, ``layout`` and ``device`` fall back to the corresponding
    attribute of ``a`` when passed as ``None``.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.empty(
        size,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.new_empty_strided)
@out_wrapper()
def new_empty_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.Tensor.new_empty_strided

    ``dtype``, ``layout`` and ``device`` fall back to the corresponding
    attribute of ``a`` when passed as ``None``; the tensor itself is created
    by ``torch.empty_strided``.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.empty_strided(
        size,
        stride,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.zeros.default)
@out_wrapper()
def zeros(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Create a zero-filled tensor by delegating to ``torch.full``.

    ``size`` may be given as varargs or as a single sequence; it is
    normalized with ``utils.extract_shape_from_varargs``. When ``dtype`` is
    ``None`` the current default dtype is used. For ``torch.bool`` the fill
    value is ``False`` rather than ``0``.
    """
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    fill_value: NumberType = 0
    if dtype == torch.bool:
        fill_value = False

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.new_zeros)
@out_wrapper()
def new_zeros(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Create a zero-filled tensor of ``size`` by delegating to ``torch.full``.

    ``dtype``, ``layout`` and ``device`` default to those of ``a`` when not
    given. For ``torch.bool`` the fill value is ``False`` rather than ``0``.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    # dtype is already resolved here, so no further fallback to a.dtype is needed.
    fill_value: NumberType = 0
    if dtype == torch.bool:
        fill_value = False

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.ones.default)
@out_wrapper()
def ones(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Create a one-filled tensor by delegating to ``torch.full``.

    ``size`` may be given as varargs or as a single sequence; it is
    normalized with ``utils.extract_shape_from_varargs``. When ``dtype`` is
    ``None`` the current default dtype is used. For ``torch.bool`` the fill
    value is ``True`` rather than ``1``.
    """
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    fill_value: NumberType = 1
    if dtype == torch.bool:
        fill_value = True

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.new_ones)
@out_wrapper()
def new_ones(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Create a one-filled tensor of ``size`` by delegating to ``torch.full``.

    ``dtype``, ``layout`` and ``device`` default to those of ``a`` when not
    given. For ``torch.bool`` the fill value is ``True`` rather than ``1``.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    # dtype is already resolved here, so no further fallback to a.dtype is needed.
    fill_value: NumberType = 1
    if dtype == torch.bool:
        fill_value = True

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.new_full)
@out_wrapper()
def new_full(
    a: TensorLikeType,
    size: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Create a tensor of ``size`` filled with ``fill_value`` via ``torch.full``.

    ``dtype``, ``layout`` and ``device`` default to those of ``a`` when not
    given.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
 | |
| 
 | |
| 
 | |
@aten.empty.out.py_impl(DispatchKey.CompositeImplicitAutograd)
def empty_out(
    size: ShapeType,
    out: TensorLikeType,
    memory_format: Optional[torch.memory_format] = None,
) -> TensorLikeType:
    """
    Python implementation registered for ``aten.empty.out``.

    Returns ``out`` unchanged. ``size`` and ``memory_format`` are accepted
    to match the operator's signature but are not used here.
    """
    return out
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.empty_like)
@out_wrapper()
def empty_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: Optional[torch.layout] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """
    Create an uninitialized tensor with the shape of ``a``.

    ``dtype``, ``layout`` and ``device`` default to those of ``a``. With
    ``torch.preserve_format`` the result is built by ``torch.empty_permuted``
    using the permutation computed from ``a``; any other ``memory_format``
    is passed straight to ``torch.empty``.
    """
    if dtype is None:
        dtype = a.dtype
    if layout is None:
        layout = a.layout
    if device is None:
        device = a.device

    if memory_format == torch.preserve_format:
        perm = utils.compute_elementwise_output_logical_to_physical_perm(a)
        # identity perm is [2, 1, 0]
        return torch.empty_permuted(
            a.shape,
            perm,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=requires_grad,
        )

    # An explicit memory format was requested.
    return torch.empty(
        a.shape,
        dtype=dtype,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
        memory_format=memory_format,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition([aten.arange.start_step, aten.arange.start_out])
@out_wrapper()
def arange(
    start: NumberType = 0,
    end: Optional[NumberType] = None,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Build the sequence ``start, start + step, ...`` bounded by ``end``.

    When ``end`` is omitted, the single positional value is treated as the
    end and ``start`` becomes 0. If ``dtype`` is ``None`` it is ``torch.int64``
    when all three arguments are int-like, otherwise the default dtype.
    Integer dtypes are produced directly by ``prims.iota``; other dtypes are
    computed as ``start + step * index`` and cast to ``dtype``.

    Checks (via ``torch._check``): ``step`` is nonzero, the bounds agree with
    the sign of ``step``, and plain-float arguments are finite.
    """
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    device = torch.device(utils.device_or_default(device))

    assert not isinstance(start, complex)
    assert not isinstance(end, complex)
    assert not isinstance(step, complex)

    # Case: torch.arange(5)
    if end is None:
        end = start
        start = 0
    torch._check(step != 0, lambda: "step must be nonzero")

    bounds_msg = "upper bound and lower bound inconsistent with step sign"
    if step > 0:
        torch._check(end >= start, lambda: bounds_msg)
    elif step < 0:
        torch._check(end <= start, lambda: bounds_msg)

    def _finite(value):
        # Only plain (non-symbolic) floats are tested for finiteness.
        return not isinstance(value, FloatWithoutSymFloat) or math.isfinite(value)

    torch._check(
        _finite(start) and _finite(end),
        lambda: f"unsupported range: {start} -> {end}",
    )
    torch._check(
        _finite(step),
        lambda: f"step must be finite but got {step}",
    )

    all_int_args = builtins.all(isinstance(v, IntLike) for v in (start, end, step))

    if dtype is None:
        dtype = torch.int64 if all_int_args else torch.get_default_dtype()

    integral_dtype = utils.is_integer_dtype(dtype)
    if integral_dtype or all_int_args:
        xstart = sym_int(start)
        xend = sym_int(end)
        xstep = sym_int(step)

    # For int64 we truncate arguments to int before calculating length, but
    # other integral dtypes we don't. Weird... but needed to match ATen shapes.
    if dtype == torch.int64 or all_int_args:
        # Uses floordiv to avoid ceil in inductor.
        sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
        length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
    else:
        length = math.ceil((end - start) / step)

    if integral_dtype:
        return prims.iota(
            length,
            start=xstart,  # type: ignore[possibly-undefined]
            step=xstep,  # type: ignore[possibly-undefined]
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

    # Non-integer dtype: start from 0..length-1 and scale/shift it.
    positions = prims.iota(
        length,
        start=0,
        step=1,
        dtype=torch.int64,
        device=device,
        requires_grad=False,
    )

    computation_dtype = (
        torch.long if all_int_args else utils.get_acc_type(dtype, device)
    )
    positions = _maybe_convert_to_dtype(positions, computation_dtype)
    result = start + step * positions
    result = _maybe_convert_to_dtype(result, dtype)

    if requires_grad:
        result.requires_grad_(True)
    return result
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.lerp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("start", "end", "weight"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
    """
    Linear interpolation between ``start`` and ``end`` by ``weight``.

    A numeric ``weight`` is first turned into a 0-dim tensor with
    ``start.new_full``. The result is computed element-wise as::

        start + weight * (end - start)        where |weight| <  0.5
        end + (weight - 1) * (end - start)    where |weight| >= 0.5

    The output is re-strided, if needed, to the strides computed for the
    broadcast tensor inputs.
    """
    tensor_inputs = [start, end]
    if isinstance(weight, Number):
        weight = start.new_full((), weight)  # type: ignore[arg-type]
    else:
        tensor_inputs.append(weight)
    assert isinstance(weight, Tensor)  # mypy

    # We implement it this way for numerical stability. We assume (in the
    # stability optimisation) that 0 <= weight <= 1. The abs deals with
    # complex numbers. Floating point is most precise near zero, so the
    # interpolation is anchored at whichever endpoint keeps the coefficient
    # small:
    # If weight.abs() >= 0.5:
    #    return (1 - weight) * (start - end) + end
    anchor_at_end = weight.abs() >= 0.5
    coeff = torch.where(anchor_at_end, weight - 1, weight)
    base = torch.where(anchor_at_end, end, start)
    output = coeff * (end - start) + base

    # make sure the decomposition output's stride is same as non-decomposition path.
    expected_stride = utils.compute_elementwise_output_strides(
        *_maybe_broadcast(*tensor_inputs)
    )
    if output.stride() != expected_stride:
        output = prims.copy_strided(output, expected_stride)

    return handle_noncontiguous_outputs(tensor_inputs, output)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.linspace)
@out_wrapper()
def linspace(
    start: Union[NumberType, TensorLikeType],
    end: Union[NumberType, TensorLikeType],
    steps: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Produce ``steps`` values running from ``start`` to ``end``.

    Tensor-valued ``start``/``end`` must be 0-dimensional and are converted
    to ``torch.float64``. If any of ``start``, ``end``, ``steps`` is a Python
    ``complex``, ``dtype`` defaults to the complex counterpart of the default
    dtype and an explicitly passed ``dtype`` must be complex; otherwise
    ``dtype`` defaults to the default dtype. ``steps`` must be int-like and
    non-negative. ``steps == 0`` and ``steps == 1`` are handled as special
    cases before the general computation.
    """

    def _scalar_tensor_to_float64(bound):
        # Tensor bounds must be 0-dim; non-tensor bounds pass through untouched.
        if isinstance(bound, TensorLikeType):
            torch._check(
                bound.dim() == 0,
                lambda: "linspace only supports 0-dimensional start and end tensors",
            )
            return _maybe_convert_to_dtype(bound, torch.float64)
        return bound

    start = _scalar_tensor_to_float64(start)
    end = _scalar_tensor_to_float64(end)

    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    elif dtype is None:
        dtype = torch.get_default_dtype()
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: (
            "received an invalid combination of arguments - got "
            f"({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})"
        ),
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    tensor_options = {
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "requires_grad": requires_grad,
    }
    if steps == 0:
        return torch.full((0,), 0, dtype=dtype, **tensor_options)  # type: ignore[arg-type]
    if steps == 1:
        if isinstance(start, TensorLikeType):
            single = torch.empty((steps,), dtype=dtype, **tensor_options)  # type: ignore[arg-type]
            return torch.ops.aten.copy.default(single, start)
        return torch.full((steps,), start, dtype=dtype, **tensor_options)  # type: ignore[arg-type]

    # Perform in arange in int because some backends like ATen or Triton do not support all the dtypes
    rg = torch.arange(0, steps, **tensor_options)  # type: ignore[arg-type]

    # Small types need to be computed in higher precision as this is, at heart, an associative scan
    if utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype):
        dtype_red = torch.int64
    else:
        dtype_red = dtype
    computation_dtype, _ = utils.reduction_dtypes(
        rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
    )
    to_computation_dtype = partial(_maybe_convert_to_dtype, dtype=computation_dtype)

    # We implement torch.lerp without performing rg / (steps - 1) explicitly
    # With this we get out[0] == start, out[-1] == end
    step = (end - start) / (steps - 1)
    in_first_half = rg < steps / 2
    from_start = start + step * to_computation_dtype(rg)  # type: ignore[arg-type,operator]
    from_end = end - step * to_computation_dtype((steps - 1) - rg)  # type: ignore[arg-type,operator]
    out = torch.where(in_first_half, from_start, from_end)
    return _maybe_convert_to_dtype(out, dtype)  # type: ignore[return-value]
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.logspace)
@out_wrapper()
def logspace(
    start: Union[NumberType, TensorLikeType],
    end: Union[NumberType, TensorLikeType],
    steps: NumberType,
    base: NumberType = 10,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Compute ``base ** linspace(start, end, steps)`` and cast it to ``dtype``.

    ``dtype`` defaults to the default dtype. For integer dtypes, float-like
    bounds are truncated with ``sym_int`` and 0-dim tensor bounds are cast to
    ``dtype`` first. If any of ``start``, ``end``, ``steps`` is a Python
    ``complex``, the output dtype becomes the complex counterpart of the
    default dtype. A negative ``base`` raises ``NotImplementedError``.
    """
    if dtype is None:
        dtype = torch.get_default_dtype()

    def _to_integral(bound):
        # Float-like scalars are truncated; tensors must be 0-dim and are cast.
        if isinstance(bound, FloatLike):
            return sym_int(bound)
        if isinstance(bound, TensorLikeType):
            torch._check(
                bound.dim() == 0,
                lambda: "logspace only supports 0-dimensional start and end tensors",
            )
            return _maybe_convert_to_dtype(bound, dtype)
        return bound

    # NB: NumPy doesn't have this cast
    if prims.utils.is_integer_dtype(dtype):
        start = _to_integral(start)
        end = _to_integral(end)

    linspace_dtype: Optional[torch.dtype]
    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
        dtype = utils.corresponding_complex_dtype(torch.get_default_dtype())
        linspace_dtype = None  # torch.linspace will update the correct dtype
    else:
        linspace_dtype = torch.float64

    assert not isinstance(base, complex)  # for mypy
    if base < 0:
        raise NotImplementedError
    exponents = torch.linspace(  # type: ignore[misc]
        start,  # type: ignore[arg-type]
        end,  # type: ignore[arg-type]
        steps,  # type: ignore[arg-type]
        dtype=linspace_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return _maybe_convert_to_dtype(torch.pow(base, exponents), dtype)  # type: ignore[arg-type,return-value]
 | |
| 
 | |
| 
 | |
@overload
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
    pass


@overload
def meshgrid(*tensors: TensorLikeType, indexing: str):
    pass


@register_decomposition(aten.meshgrid)  # type: ignore[misc]
def meshgrid(
    *tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]],
    indexing: str,
) -> list[TensorLikeType]:
    """
    Build coordinate grids from 0-D or 1-D tensors.

    Accepts either the tensors as varargs or a single list/tuple of tensors
    (see the overload stubs above). Every input is broadcast to the shape
    formed by the ``numel()`` of each input, in order.

    Args:
        tensors: the input tensors; all must share dtype and device.
        indexing: ``"ij"`` or ``"xy"``. With ``"xy"`` and at least two
            inputs, the first two inputs are swapped before the grids are
            built and the first two outputs are swapped back afterwards.

    Returns:
        One grid per input tensor.

    Raises:
        RuntimeError (through ``torch._check``) when no tensors are given,
        an input is not a tensor, dtypes/devices differ, an input has more
        than one dimension, or ``indexing`` is not ``"xy"``/``"ij"``.
    """
    # This ref simultaneously handles two overloads (see stubs above)
    # The `indexing` argument is currently optional for torch.meshgrid, but we
    # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
    #
    # Only peek at tensors[0] when there is one: with empty varargs the
    # indexing would raise IndexError before the non-empty check below gets a
    # chance to report the intended error.
    if len(tensors) > 0 and isinstance(tensors[0], (list, tuple)):
        assert len(tensors) == 1
        tensors = tuple(tensors[0])

    torch._check(
        builtins.all(isinstance(a, TensorLike) for a in tensors),
        lambda: "meshgrid expects its inputs to be tensors",
    )

    torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")

    # Compare each tensor with its successor.
    for current, following in zip(tensors, tensors[1:]):
        torch._check(
            current.dtype == following.dtype,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same dtype",
        )
        torch._check(
            current.device == following.device,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same device",
        )

    swap_first_and_second_tensors = False
    if indexing == "xy":
        swap_first_and_second_tensors = len(tensors) >= 2
        if swap_first_and_second_tensors:
            tensors = (tensors[1], tensors[0], *tensors[2:])
    else:
        torch._check(
            indexing == "ij",
            lambda: (
                'torch.meshgrid: indexing must be one of "xy" or "ij", '
                f"but received: {indexing}"
            ),
        )

    result_shape: list[int] = []
    for t in tensors:
        assert isinstance(t, TensorLike)  # mypy
        torch._check(
            t.ndim == 0 or t.ndim == 1,
            lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
        )
        result_shape.append(t.numel())

    grids: list[TensorLikeType] = []
    for i, t in enumerate(tensors):
        assert isinstance(t, TensorLike)  # mypy
        if t.ndim == 0:
            # Treat a scalar tensor as a length-1 vector so it owns axis i.
            t = t.view((1,))
        grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))

    if swap_first_and_second_tensors:
        # Swap outputs if we originally swapped at the beginning
        grids[0], grids[1] = grids[1], grids[0]

    return grids
 | |
| 
 | |
| 
 | |
| # CompositeImplicitAutograd - don't register decomp
 | |
| def movedim(
 | |
|     input: TensorLikeType,
 | |
|     source: Union[int, DimsSequenceType],
 | |
|     destination: Union[int, DimsSequenceType],
 | |
| ) -> TensorLikeType:
 | |
|     """
 | |
|     Reference implementation of torch.movedim
 | |
|     """
 | |
|     if type(source) is int:
 | |
|         source = (source,)
 | |
|     if type(destination) is int:
 | |
|         destination = (destination,)
 | |
| 
 | |
|     # Converts to list to produce a compatible error message with core PyTorch,
 | |
|     # which prints sequences in square brackets.
 | |
|     torch._check(
 | |
|         len(source) == len(destination),  # type: ignore[arg-type]
 | |
|         lambda: (
 | |
|             "movedim: Invalid source or destination dims: source "  # type: ignore[arg-type]
 | |
|             f"({list(source)} dims) should contain the same number "  # type: ignore[arg-type]
 | |
|             f"of dims as destination ({list(destination)} dims)"  # type: ignore[arg-type]
 | |
|         ),
 | |
|     )
 | |
| 
 | |
|     rank = input.ndim
 | |
|     ss = tuple(utils.canonicalize_dims(rank=rank, indices=source))  # type: ignore[arg-type]
 | |
|     ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination))  # type: ignore[arg-type]
 | |
| 
 | |
|     sss = set(ss)
 | |
|     dss = set(ds)
 | |
| 
 | |
|     # See above on why this converts to list in error messages.
 | |
|     torch._check(
 | |
|         len(ss) == len(sss),
 | |
|         lambda: f"movedim: repeated dim in `source` ({list(source)})",  # type: ignore[arg-type]
 | |
|     )
 | |
|     torch._check(
 | |
|         len(ds) == len(dss),
 | |
|         lambda: f"movedim: repeated dim in `destination` ({list(destination)})",  # type: ignore[arg-type]
 | |
|     )
 | |
| 
 | |
|     m = dict(zip(ds, ss))
 | |
|     dims = []
 | |
|     si = 0  # source index
 | |
|     for di in range(rank):
 | |
|         # check if the destination index is in the mapping
 | |
|         s = m.get(di)
 | |
|         if s is not None:
 | |
|             # insert source index if found
 | |
|             dims.append(s)
 | |
|         else:
 | |
|             # insert source index sequentially, skipping indices from the mapping
 | |
|             while si in sss:
 | |
|                 si += 1
 | |
|             dims.append(si)
 | |
|             si += 1
 | |
| 
 | |
|     result = torch.permute(input, tuple(dims))
 | |
| 
 | |
|     return result
 | |
| 
 | |
| 
 | |
# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
@register_decomposition(aten.empty_strided)
@out_wrapper()
def empty_strided(
    shape: Union[ShapeType, tuple[ShapeType]],
    strides: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Create an uninitialized tensor with the given ``shape`` and ``strides``
    by calling ``prims.empty_strided``.

    ``layout`` and ``pin_memory`` are validated with ``utils.check_layout``
    and ``utils.check_pin_memory``. ``dtype`` defaults to the default dtype
    and ``device`` to CPU.
    """
    # Layout == strided, pin_memory is False
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    shape = utils.extract_shape_from_varargs(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device("cpu")

    return prims.empty_strided(
        shape,
        strides,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.eye)
@out_wrapper()
def eye(
    n: int,
    m: Optional[int] = None,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,  # TODO: unused
) -> TensorLikeType:
    """
    Reference implementation of torch.eye

    Builds an ``n`` x ``m`` tensor (``m`` defaults to ``n``) by comparing a
    column of row indices against a row of column indices. For
    ``torch.bool`` the comparison mask itself is returned; otherwise the
    mask selects between a one of the requested dtype and 0.
    """
    if m is None:
        m = n

    torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
    torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")

    row_index = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
    col_index = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)

    # (n, 1) == (m,) broadcasts to an (n, m) mask that is True where row == col.
    on_diagonal = row_index.unsqueeze(-1) == col_index
    if dtype is torch.bool:
        return on_diagonal

    # TODO: Use requires_grad.  All refs taking the requires_grad kwarg must
    # return a leaf tensor.
    # result.requires_grad_(requires_grad)
    one = torch.ones(
        (1,),
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=False,
    )
    return torch.where(on_diagonal, one, 0)
 | |
| 
 | |
| 
 | |
@register_decomposition([aten.full.default, aten.full.out])
@out_wrapper()
def full(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    """
    Create a tensor of ``shape`` in which every element is ``fill_value``.

    ``dtype`` defaults to the dtype corresponding to the Python type of
    ``fill_value`` and ``device`` defaults to CPU. The tensor is allocated
    with ``empty`` and then filled with ``torch.fill``.
    """
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    if dtype is None:
        dtype = utils.type_to_dtype(type(fill_value))
    if device is None:
        device = torch.device("cpu")

    buffer = empty(
        shape,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return torch.fill(buffer, fill_value)  # type: ignore[arg-type]
 | |
| 
 | |
| 
 | |
def full_like(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """
    Create a tensor shaped like ``a`` with every element set to ``fill_value``.

    All keyword options are handed to ``torch.empty_like`` unchanged; the
    resulting tensor is then passed through ``fill``.
    """
    like_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "requires_grad": requires_grad,
        "memory_format": memory_format,
    }
    buffer = torch.empty_like(a, **like_options)
    return fill(buffer, fill_value)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.zeros_like)
@out_wrapper()
def zeros_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """
    Create a zero-filled tensor shaped like ``a`` through ``torch.full_like``.

    The fill value is ``False`` when the effective dtype (``dtype`` if
    given, else ``a.dtype``) is ``torch.bool``, and ``0`` otherwise.
    """
    effective_dtype = a.dtype if dtype is None else dtype
    fill_value: NumberType = 0
    if effective_dtype == torch.bool:
        fill_value = False

    return torch.full_like(
        a,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.ones_like)
@out_wrapper()
def ones_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    """
    Create a one-filled tensor shaped like ``a`` through ``torch.full_like``.

    The fill value is ``True`` when the effective dtype (``dtype`` if
    given, else ``a.dtype``) is ``torch.bool``, and ``1`` otherwise.
    """
    effective_dtype = a.dtype if dtype is None else dtype
    fill_value: NumberType = 1
    if effective_dtype == torch.bool:
        fill_value = True

    return torch.full_like(
        a,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.randn.default)
@out_wrapper()
def randn(
    *shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: Optional[torch.layout] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Sample a tensor from ``prims.normal`` with ``mean=0.0`` and ``std=1.0``.

    ``shape`` may be given as varargs or as a single sequence. ``dtype`` and
    ``device`` fall back to ``utils.dtype_or_default`` /
    ``utils.device_or_default``. ``pin_memory`` is validated; ``layout`` is
    accepted but not used.
    """
    utils.check_pin_memory(pin_memory)

    normalized_shape = utils.extract_shape_from_varargs(shape)

    resolved_dtype = utils.dtype_or_default(dtype)
    resolved_device = utils.device_or_default(device)

    return prims.normal(
        normalized_shape,
        mean=0.0,
        std=1.0,
        dtype=resolved_dtype,
        device=resolved_device,
        requires_grad=requires_grad,
    )
 | |
| 
 | |
| 
 | |
def scalar_tensor(
    a: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Wrap the number ``a`` in a tensor via ``prims.scalar_tensor``.

    ``dtype`` defaults to the dtype corresponding to the Python type of
    ``a`` and ``device`` defaults to CPU. ``layout`` and ``pin_memory`` are
    only validated.
    """
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    if dtype is None:
        dtype = utils.type_to_dtype(type(a))
    if device is None:
        device = torch.device("cpu")

    return prims.scalar_tensor(a, dtype=dtype, device=device)
 | |
| 
 | |
| 
 | |
| #
 | |
| # Randomness References
 | |
| #
 | |
| 
 | |
| 
 | |
| def _uniform_helper(
 | |
|     shape: ShapeType,
 | |
|     low: Union[bool, int, float] = 0.0,
 | |
|     high: Union[bool, int, float] = 1.0,
 | |
|     *,
 | |
|     dtype: torch.dtype,
 | |
|     device: DeviceLikeType,
 | |
| ) -> TensorLikeType:
 | |
|     utils.validate_shape(shape)
 | |
| 
 | |
|     assert isinstance(low, Number)
 | |
|     assert isinstance(high, Number)
 | |
|     low = sym_float(low)
 | |
|     high = sym_float(high)
 | |
| 
 | |
|     assert isinstance(dtype, torch.dtype)
 | |
|     device = utils.canonicalize_device(device)
 | |
| 
 | |
|     return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.masked_fill)
@out_wrapper()
def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
    """Reference for ``aten.masked_fill`` expressed through ``torch.where``.

    ``value`` is either a Python number or a 0-dimensional tensor. A tensor
    value must live on ``a``'s device, except that a CPU value is accepted
    when ``a`` is on one of the device types listed below. A complex value is
    rejected unless ``a``'s Python type can hold it.
    """
    python_type = utils.dtype_to_type(a.dtype)

    if not isinstance(value, Number):
        # NOTE: Could not use value = item(value) as it resulted in
        # RuntimeError: Cannot cast FakeTensor(cpu) to number
        value_ndim = value.ndim
        torch._check(
            value_ndim == 0,
            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
        )
        # A CPU value tensor may be combined with `a` on any of these device
        # types; every other device mismatch is an error.
        cpu_value_allowed_on = [
            "cuda",
            "xpu",
            "mps",
            torch._C._get_privateuse1_backend_name(),
            "hpu",
        ]
        is_cpu_scalar = (
            a.device.type in cpu_value_allowed_on and value.device.type == "cpu"
        )
        torch._check(
            is_cpu_scalar or value.device == a.device,
            lambda: "Expected `value` to be on same device as `a`",
        )
        value_type = utils.dtype_to_type(value.dtype)
    else:
        value_type = type(value)

    if value_type is complex:
        # Only downcasting from complex to a lower type is disallowed; other
        # narrowing conversions of `value` (e.g. float -> int) are permitted.
        # Ref: https://github.com/pytorch/pytorch/issues/79195
        torch._check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # `where` type-promotes its operands, so cast the fill value to a's dtype
    # first to keep the result in a's dtype.
    fill = _maybe_convert_to_dtype(value, a.dtype)
    filled = torch.where(mask, fill, a)  # type: ignore[arg-type]

    # aten.masked_fill always returns a new contiguous tensor;
    # contiguous() is needed to correctly model the output stride.
    return filled.contiguous()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.masked_fill_)
def masked_fill_(
    a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
) -> TensorLikeType:
    """In-place ``masked_fill``: compute the filled tensor, copy it into ``a``,
    and return ``a``."""
    filled = torch.masked_fill(a, mask, value)  # type: ignore[arg-type]
    a.copy_(filled)
    return a
 | |
| 
 | |
| 
 | |
# CompositeImplicitAutograd - don't register decomp
def allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose

    Arguments are validated with ``_check_close_args``; the result is the
    Python ``bool`` of ``torch.all`` applied to ``torch.isclose(a, b, ...)``.
    """
    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    elementwise_close = torch.isclose(
        a, b, rtol=rtol, atol=atol, equal_nan=equal_nan
    )
    every_close = torch.all(elementwise_close)
    return bool(every_close.item())
 | |
| 
 | |
| 
 | |
def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    """Return whether ``a`` and ``b`` have the same shape and equal elements.

    Device and dtype agreement are enforced up front by
    ``utils.check_same_device`` (CPU scalar tensors are not exempted) and
    ``utils.check_same_dtype``. A rank or per-dimension size mismatch yields
    ``False``; two empty tensors of the same shape yield ``True``.
    """
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape check: rank first, then each dimension in order.
    if a.ndim != b.ndim:
        return False
    for size_a, size_b in zip(a.shape, b.shape):
        if size_a != size_b:
            return False

    # Nothing to compare element-wise.
    if a.numel() == 0:
        return True

    # `all` and `eq` here are the references defined in this module.
    elementwise = eq(a, b)
    return item(all(elementwise))  # type: ignore[return-value]
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.norm)
@out_wrapper(exact_dtype=True)
def norm(
    input: TensorLikeType,
    p: Optional[Union[float, str]] = "fro",
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    """Reference for ``aten.norm``; dispatches to ``torch.linalg`` norms.

    ``p=None`` and ``p="fro"`` over at most two dims (or no / a single dim)
    are treated as the 2-norm. Any remaining string ``p`` goes to
    ``torch.linalg.matrix_norm``; numeric ``p`` goes to
    ``torch.linalg.vector_norm``.
    """
    # In these cases we compute the "Frobenius norm" as a vector 2-norm.
    if p is None:
        p = 2
    elif p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2):
        p = 2

    if isinstance(dim, Dim):
        dim = [dim]

    if not isinstance(p, str):
        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)

    # Here we either call the nuclear norm, or we call matrix_norm with some
    # arguments that will throw an error.
    if dim is None:
        dim = tuple(range(input.ndim))
    return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.trace)
@out_wrapper()
def trace(self: TensorLikeType) -> TensorLikeType:
    """Reference for ``aten.trace``: sum of the main diagonal of a matrix.

    Raises (via ``torch._check``) when ``self`` is not 2-dimensional.
    """
    torch._check(
        self.ndim == 2,
        # f-string so the actual rank is interpolated into the message
        lambda: f"expected a matrix, but got tensor with dim {self.ndim}",
    )
    return torch.sum(torch.diag(self, 0))
 | |
| 
 | |
| 
 | |
| def _make_r_binary_op(base_op):
 | |
|     def rop(
 | |
|         a: Union[TensorLikeType, NumberType],
 | |
|         b: Union[TensorLikeType, NumberType],
 | |
|     ) -> TensorLikeType:
 | |
|         return base_op(b, a)
 | |
| 
 | |
|     return rop
 | |
| 
 | |
| 
 | |
# Reflected binary references: each calls its base op with the two operands
# swapped (see _make_r_binary_op).
rtruediv = _make_r_binary_op(true_divide)
rfloordiv = _make_r_binary_op(floor_divide)
rpow = _make_r_binary_op(pow)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    """Reference for ``aten.triu``.

    Keeps the entries of the last two dims whose (column - row) index
    difference is ``>= diagonal``; the rest are masked out with
    ``utils.mask_tensor``. Requires ``a.ndim >= 2``.
    """
    torch._check(
        a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
    )
    n_rows, n_cols = a.shape[-2:]
    col_index = torch.arange(n_cols, device=a.device).unsqueeze(-2)
    row_index = torch.arange(n_rows, device=a.device).unsqueeze(-1)
    keep = (col_index - row_index) >= diagonal

    # aten.triu always returns a new contiguous tensor;
    # contiguous() is needed to correctly model the output stride.
    masked = utils.mask_tensor(keep, a)
    return masked.contiguous()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    """Reference for ``aten.tril``.

    Keeps the entries of the last two dims whose (column - row) index
    difference is ``<= diagonal``; the rest are masked out with
    ``utils.mask_tensor``. Requires ``a.ndim >= 2``.
    """
    torch._check(
        a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
    )
    n_rows, n_cols = a.shape[-2:]
    col_index = torch.arange(n_cols, device=a.device).unsqueeze(-2)
    row_index = torch.arange(n_rows, device=a.device).unsqueeze(-1)
    keep = (col_index - row_index) <= diagonal

    # aten.tril always returns a new contiguous tensor;
    # contiguous() is needed to correctly model the output stride.
    masked = utils.mask_tensor(keep, a)
    return masked.contiguous()
 | |
| 
 | |
| 
 | |
| # This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
 | |
| # The components of the matrix that belong to the lower triangle with offset
 | |
| # form a pentagon that can be broken down into a top trapezoid and a bottom
 | |
| # rectangle. For the implementation of tril_indices, we need the sizes of
 | |
| # both of these, as well as the length of the top side of the trapezoid.
 | |
| def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
 | |
|     if row == 0 or col == 0:
 | |
|         return 0, 0, 0
 | |
| 
 | |
|     m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
 | |
|     m_last_row = max(0, min(col, row + offset))
 | |
|     n_row_all = max(0, min(row, row + offset))
 | |
|     n_row_trapezoid = m_last_row - m_first_row + 1
 | |
| 
 | |
|     # Number of elements in top trapezoid
 | |
|     trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
 | |
|     # Number of elements in bottom rectangle
 | |
|     diff_row = n_row_all - n_row_trapezoid
 | |
|     rectangle_size = max(0, diff_row * col)
 | |
| 
 | |
|     return trapezoid_size, rectangle_size, m_first_row
 | |
| 
 | |
| 
 | |
| def _trilu_checks(
 | |
|     name: str,
 | |
|     row: int,
 | |
|     col: int,
 | |
|     dtype: torch.dtype,
 | |
|     layout: torch.layout,
 | |
|     pin_memory: bool,
 | |
| ):
 | |
|     torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
 | |
|     torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
 | |
|     torch._check(
 | |
|         dtype in (torch.int32, torch.int64),
 | |
|         lambda: f"\"{name}\" not implemented for '{dtype}'",
 | |
|     )
 | |
| 
 | |
| 
 | |
# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
@register_decomposition(aten.tril_indices)
@out_wrapper()
def tril_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    """Reference for ``aten.tril_indices``.

    Returns the stack of a row-index tensor and a column-index tensor.
    Indices are produced in two pieces whose sizes come from
    ``_get_tril_sizes``: the top trapezoid first, then the bottom rectangle.
    """
    _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)

    n_trapezoid, n_rectangle, first_row_len = _get_tril_sizes(row, col, offset)
    # A negative offset pushes the first selected row down by -offset.
    row_shift = max(0, -offset)

    make_range = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # Top trapezoid: recover (row, col) from the flat position with a
    # square-root formula evaluated in float64, then cast to `dtype`.
    flat_trap = make_range(0, n_trapezoid, dtype=torch.float64)
    b = first_row_len - 0.5
    trap_rows = torch.floor(-b + torch.sqrt(b * b + 2 * flat_trap))
    trap_cols = torch.floor(
        flat_trap - (2 * first_row_len - 1 + trap_rows) * trap_rows * 0.5
    )
    trap_rows = _maybe_convert_to_dtype(trap_rows + row_shift, dtype)
    trap_cols = _maybe_convert_to_dtype(trap_cols, dtype)

    # Bottom rectangle: full-width rows, so plain div/mod by `col`.
    flat_rect = make_range(0, n_rectangle, dtype=dtype)
    rect_rows = flat_rect // col + (col - first_row_len + 1 + row_shift)
    rect_cols = flat_rect % col

    all_rows = torch.cat((trap_rows, rect_rows))
    all_cols = torch.cat((trap_cols, rect_cols))
    return torch.stack((all_rows, all_cols))
 | |
| 
 | |
| 
 | |
# Counterpart of _get_tril_sizes for the upper triangle: here the region is
# split into a top rectangle and a bottom trapezoid. Note that you can't
# reduce this to _get_tril_sizes(col, row, -offset) because that would
# correspond to decomposing into a left trapezoid and right rectangle.
def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
    """Return ``(trapezoid_size, rectangle_size, m_first_row)``.

    The total upper-triangle count is obtained as ``row * col`` minus the
    lower-triangle count for ``offset - 1``; the trapezoid gets whatever is
    left after the top rectangle. An empty matrix yields ``(0, 0, 0)``.
    """
    if row == 0 or col == 0:
        return 0, 0, 0

    if offset > 0:
        first_width = max(0, col - offset)
    else:
        first_width = col

    # Number of elements in top rectangle
    rectangle_size = max(0, min(row, -offset) * col)

    # Number of elements in bottom trapezoid
    tril_trapezoid, tril_rectangle, _ = _get_tril_sizes(row, col, offset - 1)
    upper_total = row * col - (tril_trapezoid + tril_rectangle)
    trapezoid_size = upper_total - rectangle_size

    return trapezoid_size, rectangle_size, first_width
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.triu_indices)
@out_wrapper()
def triu_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    """Reference for ``aten.triu_indices``.

    Returns the stack of a row-index tensor and a column-index tensor.
    Indices are produced in two pieces whose sizes come from
    ``_get_triu_sizes``: the top rectangle first, then the bottom trapezoid.
    """
    _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)

    n_trapezoid, n_rectangle, first_row_len = _get_triu_sizes(row, col, offset)
    # A positive offset shifts the trapezoid's columns right by `offset`.
    col_shift = max(0, offset)

    make_range = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # Top rectangle: full-width rows, so plain div/mod by `col`.
    flat_rect = make_range(0, n_rectangle, dtype=dtype)
    rect_rows = flat_rect // col
    rect_cols = flat_rect % col

    # Bottom trapezoid: recover (row, col) from the flat position with a
    # square-root formula evaluated in float64, then cast to `dtype`.
    flat_trap = make_range(0, n_trapezoid, dtype=torch.float64)
    b = -0.5 - first_row_len
    trap_rows = torch.floor(-b - torch.sqrt(b * b - 2 * flat_trap))
    trap_cols = torch.floor(
        flat_trap - ((2 * first_row_len - 1 - trap_rows) * trap_rows) * 0.5
    )
    trap_rows = _maybe_convert_to_dtype(trap_rows, dtype)
    trap_cols = _maybe_convert_to_dtype(trap_cols, dtype)

    # Trapezoid rows start below the rectangle; skipped when col == 0 to
    # avoid dividing by zero.
    if col:
        trap_rows = trap_rows + (n_rectangle // col)
    trap_cols = trap_cols + col_shift

    all_rows = torch.cat((rect_rows, trap_rows))
    all_cols = torch.cat((rect_cols, trap_cols))
    return torch.stack((all_rows, all_cols))
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.bucketize)
@out_wrapper(exact_dtype=True)
def bucketize(
    a: TensorOrNumberLikeType,
    boundaries: TensorLikeType,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    """Reference for ``aten.bucketize``.

    For every element of ``a``, finds the index of the bucket in the
    1-dimensional ``boundaries`` tensor it belongs to, using a binary search
    that advances all elements in lock-step.

    Args:
        a: tensor or Python number to bucketize (numbers are wrapped with
            ``torch.tensor``).
        boundaries: 1-dimensional tensor of bucket boundaries.
        out_int32: return ``torch.int32`` indices instead of ``torch.int64``.
        right: if true the search uses ``boundary > a`` (upper bound),
            otherwise ``boundary >= a`` (lower bound).

    Returns:
        A tensor of bucket indices with ``a``'s shape and the dtype selected
        by ``out_int32``.
    """
    torch._check(
        boundaries.dim() == 1,
        lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
    )

    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    out_dtype = torch.int32 if out_int32 else torch.int64
    n_boundaries = boundaries.shape[-1]
    if n_boundaries == 0:
        # With no boundaries every element lands in bucket 0. The result is an
        # index tensor, so it must use out_dtype rather than inherit a's dtype.
        return torch.zeros_like(a, dtype=out_dtype)
    # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
    # each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
    # but each step of the search is done "in parallel" over all elements of `a`
    # can't use int32 as indexes, so we have to do all computations with int64 and convert at the end
    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
    end = start + n_boundaries
    # Max depth of the binary search
    # Since we can't break out of the loop at different points for different elements of a,
    # we just do the max amount of iterations that binary search requires and add condition
    # tensor (cond_update below) to stop updating once the search terminates

    # For first iteration through loop we can skip some checks, we have separate implementation
    mid = start + (end - start) // 2
    mid_val = boundaries[mid]
    if right:
        cond_mid = mid_val > a
    else:
        cond_mid = mid_val >= a
    start = torch.where(cond_mid, start, mid + 1)

    if n_boundaries > 1:
        cond_update = torch.ones_like(a, dtype=torch.bool)
        niters = int(math.log2(n_boundaries))
        for _ in range(niters):
            end = torch.where(cond_mid & cond_update, mid, end)
            cond_update = start < end
            # start might end up pointing to 1 past the end, we guard against that
            mid = torch.where(cond_update, start + (end - start) // 2, 0)
            mid_val = boundaries[mid]
            # If right is true, the buckets are closed on the *left*
            # (i.e., we are doing the equivalent of std::upper_bound in C++)
            # Otherwise they are closed on the right (std::lower_bound)
            if right:
                cond_mid = mid_val > a
            else:
                cond_mid = mid_val >= a
            start = torch.where((~cond_mid) & cond_update, mid + 1, start)

    return start.to(dtype=out_dtype)
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.cauchy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def cauchy(self, median=0, sigma=1, generator=None):
    """Reference for ``aten.cauchy``.

    Draws ``torch.rand_like(self)`` and maps it through
    ``median + sigma * tan(pi * (u - 0.5))``. ``self`` must have a floating
    point dtype and ``sigma`` must be positive; a non-``None`` ``generator``
    is not supported.
    """
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        lambda: "Cauchy distribution is a continuous probability distribution. "
        f"dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        sigma > 0.0,
        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
    )
    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    """Reference for ``aten.exponential``.

    Draws ``torch.rand_like(self)`` and returns ``-1 / rate * log(u)``, with
    the log clamped away from zero for samples at or near 1. ``self`` must
    have a floating point dtype and ``rate`` must be positive; a
    non-``None`` ``generator`` is not supported.
    """
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation in the message.
        lambda: "Exponential distribution is a continuous probability distribution. "
        f"dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )

    uniform_val = torch.rand_like(self)

    # copying numerics of transformation::exponential see comment:
    # curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
    # we need log to be not 0, and not underflow when converted to half
    # fast __logf approximation can underflow, so set log to -epsilon/2 for 1 or close to 1 args
    epsilon = torch.finfo(uniform_val.dtype).eps / 2
    condition = uniform_val >= 1.0 - epsilon
    log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val))

    return -1 / rate * log_uniform
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.geometric)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def geometric(self, p, generator=None):
    """Reference for ``aten.geometric``.

    Computes ``floor(log1p(-u) / log1p(-p)) + 1`` for
    ``u = torch.rand_like(self)``. Complex and boolean dtypes are rejected and
    ``p`` must lie strictly between 0 and 1; a non-``None`` ``generator`` is
    not supported.
    """
    assert generator is None
    # TODO: fix inductor rand_like for integer, bool dtypes
    dtype_supported = not (
        utils.is_complex_dtype(self.dtype) or utils.is_boolean_dtype(self.dtype)
    )
    torch._check(
        dtype_supported,
        lambda: f"geometric not implemented for {self.dtype}",
    )
    torch._check(
        0 < p < 1,
        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
    )
    log_complement = torch.log1p(-torch.rand_like(self))
    return torch.floor(log_complement / math.log1p(-p)) + 1
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.log_normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def log_normal(self, mean=1, std=2, generator=None):
    """Reference for ``aten.log_normal``.

    Returns ``exp(std * z + mean)`` for ``z = torch.randn_like(self)``.
    ``self`` must have a floating point dtype and ``std`` must be positive;
    a non-``None`` ``generator`` is not supported.
    """
    assert generator is None
    is_floating = not (
        utils.is_complex_dtype(self.dtype)
        or utils.is_integer_dtype(self.dtype)
        or utils.is_boolean_dtype(self.dtype)
    )
    torch._check(
        is_floating,
        lambda: f"log_normal not implemented for {self.dtype}",
    )
    torch._check(
        0 < std,
        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
    )
    standard_normal = torch.randn_like(self)
    return torch.exp(std * standard_normal + mean)
 | |
| 
 | |
| 
 | |
# TODO: add support for functionalization aten.normal_functional
# NOTE: the device and dtype will be ignored when shape is None
@register_decomposition(aten.normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=(
        "mean",
        "std",
    ),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def normal(
    mean=0,
    std=1,
    size=None,
    *,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    """Reference for ``aten.normal``.

    Two calling modes are supported:

    * ``size`` omitted: at least one of ``mean`` / ``std`` must be a tensor.
      The output shape is the broadcast of the tensor arguments' shapes, and
      dtype/device are taken from the first tensor argument.
    * ``size`` given: ``mean`` and ``std`` must both be scalars; ``dtype``
      and ``device`` fall back to the default dtype and CPU.

    Standard-normal samples are drawn with ``prims.normal`` and then scaled
    by ``std`` and shifted by ``mean``.
    """
    assert layout is None or layout == torch.strided

    if not isinstance(std, TensorLike):
        torch._check(
            std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
        )

    if size is not None:
        torch._check(
            not (isinstance(mean, TensorLike) or isinstance(std, TensorLike)),
            lambda: "normal expects mean and std to be scalars when size is defined",
        )
        if dtype is None:
            dtype = torch.get_default_dtype()
        if device is None:
            device = torch.device("cpu")
    else:
        tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
        torch._check(
            len(tensors) > 0,
            lambda: "normal expects that either mean or std is a tensor, or size is defined",
        )
        torch._check(
            layout is None and pin_memory is None,
            lambda: "Cannot pass layout, or pin_memory without size",
        )

        # Shape comes from broadcasting; dtype/device from the first tensor.
        size = _broadcast_shapes(*[t.shape for t in tensors])
        reference = tensors[0]
        dtype = reference.dtype
        device = reference.device

    normal_samples = prims.normal(
        size,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=False,
        generator=generator,
    )
    return std * normal_samples + mean
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.normal_)
def normal_(self, mean=0, std=1, *, generator=None):
    """In-place ``normal``: sample with ``self``'s shape and write into ``self``
    through the ``out=`` argument."""
    target_shape = self.shape
    return normal(mean, std, target_shape, out=self, generator=generator)
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
    """Multiply ``self`` by the constant 180/pi; complex inputs are rejected."""
    is_complex = utils.is_complex_dtype(self.dtype)
    torch._check(
        not is_complex,
        lambda: "rad2deg is not supported for complex tensors.",
    )
    M_180_PI = 57.295779513082320876798154814105170332405472466564
    scaled = self * M_180_PI
    return scaled
 | |
| 
 | |
| 
 | |
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
    """Multiply ``self`` by the constant pi/180; complex inputs are rejected."""
    is_complex = utils.is_complex_dtype(self.dtype)
    torch._check(
        not is_complex,
        lambda: "deg2rad is not supported for complex tensors.",
    )
    M_PI_180 = 0.017453292519943295769236907684886127134428718885417
    scaled = self * M_PI_180
    return scaled
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.count_nonzero)
@out_wrapper()
def count_nonzero(self, dim: Optional[DimsType] = None):
    """Count entries different from zero by summing the ``self != 0`` mask
    over ``dim``."""
    nonzero_mask = self != 0
    return nonzero_mask.sum(dim)
 | |
| 
 | |
| 
 | |
| def _dot_check(self, other):
 | |
|     torch._check(
 | |
|         self.dim() == 1 and other.dim() == 1,
 | |
|         lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
 | |
|     )
 | |
| 
 | |
|     torch._check(
 | |
|         self.dtype == other.dtype,
 | |
|         lambda: "dot : expected both vectors to have same dtype, but found "
 | |
|         f"{self.dtype} and {other.dtype}",
 | |
|     )
 | |
| 
 | |
|     def numel_error():
 | |
|         return (
 | |
|             f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
 | |
|             f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
 | |
|         )
 | |
| 
 | |
|     torch._check(self.numel() == other.numel(), numel_error)
 | |
| 
 | |
| 
 | |
def _dot_check_wrapper(fn):
    """Decorator that runs ``_dot_check`` on the two operands before calling
    ``fn(self, other)``."""

    @wraps(fn)
    def checked(self, other):
        _dot_check(self, other)
        result = fn(self, other)
        return result

    return checked
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.dot)
@out_wrapper(exact_dtype=True)
@_dot_check_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot(self, other):
    """Reference for ``aten.dot``.

    For complex inputs carrying the conj bit, the call is rewritten in terms
    of ``torch.dot`` / ``torch.vdot`` on the un-conjugated tensors; otherwise
    the result is ``(self * other).sum()``.
    """
    if self.is_complex():
        self_conj = self.is_conj()
        other_conj = other.is_conj()
        if self_conj and other_conj:
            return torch.dot(self.conj(), other.conj()).conj()
        if self_conj:
            return torch.vdot(self.conj(), other)
        if other_conj:
            return torch.vdot(other.conj(), self)

    product = self * other
    return product.sum()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.vdot)
@out_wrapper(exact_dtype=True)
@_dot_check_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vdot(self, other):
    """Reference for ``aten.vdot``.

    Non-complex inputs defer to ``torch.dot``. For complex inputs carrying
    the conj bit, the call is rewritten in terms of ``torch.dot`` /
    ``torch.vdot`` on the un-conjugated tensors; otherwise the result is
    ``(self.conj_physical() * other).sum()``.
    """
    if not self.is_complex():
        return torch.dot(self, other)

    self_conj = self.is_conj()
    other_conj = other.is_conj()
    if self_conj and other_conj:
        return torch.vdot(other.conj(), self.conj())
    if self_conj:
        return torch.dot(self.conj(), other)
    if other_conj:
        return torch.dot(self, other.conj()).conj()

    # The decomposition fails if you do self.conj()... not sure why
    product = self.conj_physical() * other
    return product.sum()
 | |
| 
 | |
| 
 | |
@register_decomposition(aten.select_scatter)
@out_wrapper()
def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int):
    """Reference for ``aten.select_scatter``.

    Returns ``x`` with the slice at position ``index`` along ``dim`` taken
    from ``src`` (unsqueezed at ``dim`` and expanded to ``x``'s shape). A
    negative ``index`` is offset by ``x.shape[dim]``.
    """
    dim = utils.canonicalize_dim(x.ndim, dim)

    # View shape that is -1 along `dim` and 1 elsewhere, so the position
    # vector broadcasts against x.
    mask_shape = [-1 if axis == dim else 1 for axis in range(x.ndim)]

    if index < 0:
        index = index + x.shape[dim]

    positions = torch.arange(x.shape[dim], device=x.device).view(mask_shape)
    selected = positions == index
    expanded_src = torch.unsqueeze(src, dim).expand(x.shape)
    return torch.where(selected, expanded_src, x)
 | |
| 
 | |
| 
 | |
# In-place variants.
# Each `<name>_` below is produced by passing the corresponding out-of-place
# reference to _make_inplace (defined elsewhere in this file).
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
cumprod_ = _make_inplace(cumprod)
deg2rad_ = _make_inplace(deg2rad)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
rad2deg_ = _make_inplace(rad2deg)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)
 | |
| 
 | |
# `*_copy` variants.
# Each name below is produced by passing the corresponding aten view op to
# _make_copy_from_view (defined elsewhere in this file).
alias_copy = _make_copy_from_view(aten.alias)
as_strided_copy = _make_copy_from_view(aten.as_strided)
diagonal_copy = _make_copy_from_view(aten.diagonal)
expand_copy = _make_copy_from_view(aten.expand)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
squeeze_copy = _make_copy_from_view(aten.squeeze)
permute_copy = _make_copy_from_view(aten.permute)
t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
# unbind is the only entry that sets return_none_on_out_variant=True.
unbind_copy = _make_copy_from_view(aten.unbind, return_none_on_out_variant=True)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)
 | |
| 
 | |
| 
 | |
| # xref: isStorage in torch/csrc/DynamicTypes.cpp
 | |
| def _isStorage(obj):
 | |
|     return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage))
 | |
| 
 | |
| 
 | |
| # xref: compute_sizes in torch/csrc/utils/tensor_new.cpp
 | |
def _compute_sizes(seq, scalar_type):
    """Return the shape implied by a nested list/tuple.

    Walks down the first element of each nesting level and records the length
    found at that level.  A non-sequence input yields an empty list.

    Args:
        seq: candidate (possibly nested) list or tuple.
        scalar_type: object exposing ``itemsize``; only consulted when ``seq``
            is a storage, in which case each length is divided by it.

    Returns:
        A list of ints, one entry per nesting level visited.

    Raises:
        ValueError: if more than 128 levels are found, or the first element
            of a non-empty level cannot be fetched.
    """
    MAX_DIMS = 128
    from_storage = _isStorage(seq)
    dims = []
    level = seq
    # TODO: this is inaccurate, we actually test PySequence_Check
    while isinstance(level, (list, tuple)):
        extent = len(level)
        if from_storage:
            extent //= scalar_type.itemsize
        dims.append(extent)
        if len(dims) > MAX_DIMS:
            raise ValueError(f"too many dimensions '{type(level).__name__}'")
        if extent == 0:
            # Nothing to descend into; the shape ends with this zero extent.
            return dims
        try:
            # Only the first element determines the extent of deeper levels.
            level = level[0]
        except Exception:
            raise ValueError(  # noqa: B904
                f"could not determine the shape of object type '{type(level).__name__}'"
            )

    return dims
 | |
| 
 | |
| 
 | |
| # xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp
 | |
| def _infer_scalar_type(obj):
 | |
|     if isinstance(obj, FloatLike):
 | |
|         return torch.get_default_dtype()
 | |
|     if isinstance(obj, IntLike) and not isinstance(obj, bool):  # careful!
 | |
|         return torch.int64
 | |
|     if isinstance(obj, BoolLike):
 | |
|         return torch.bool
 | |
|     if isinstance(obj, complex):
 | |
|         default_dtype = torch.get_default_dtype()
 | |
|         if default_dtype is torch.float:
 | |
|             return torch.cfloat
 | |
|         elif default_dtype is torch.double:
 | |
|             return torch.cdouble
 | |
|         elif default_dtype is torch.half:
 | |
|             return torch.chalf
 | |
|         else:
 | |
|             raise RuntimeError("invalid default scalar type for complex")
 | |
|     if isinstance(obj, torch.Tensor):
 | |
|         return obj.dtype
 | |
|     if isinstance(obj, str):
 | |
|         raise TypeError(f"new(): invalid data type '{type(obj).__name__}'")
 | |
|     # TODO: this is inaccurate, we actually test PySequence_Check
 | |
|     if isinstance(obj, (list, tuple)):
 | |
|         scalarType = None
 | |
|         length = len(obj)
 | |
|         # match NumPy semantics, except use default tensor type instead of
 | |
|         # double.
 | |
|         if length == 0:
 | |
|             return torch.get_default_dtype()
 | |
|         for i in range(length):
 | |
|             cur_item = obj[i]
 | |
|             # TODO: test this
 | |
|             """
 | |
|             if cur_item is obj:
 | |
|                 raise TypeError("new(): self-referential lists are incompatible")
 | |
|             """
 | |
|             item_scalarType = _infer_scalar_type(cur_item)  # recurse!
 | |
|             if scalarType is not None:
 | |
|                 scalarType = torch.promote_types(scalarType, item_scalarType)
 | |
|             else:
 | |
|                 scalarType = item_scalarType
 | |
|             if scalarType is torch.cdouble:
 | |
|                 # this won't change (unless we hit undefined, but that will
 | |
|                 # fail later)
 | |
|                 return scalarType
 | |
|         return scalarType
 | |
|     raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}")
 | |
| 
 | |
| 
 | |
| # Analogous to recursive_store
 | |
| # xref: recursive_store in torch/csrc/utils/tensor_new.cpp
 | |
| def _recursive_build(
 | |
|     scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
 | |
| ):
 | |
|     if isinstance(obj, Tensor) and obj.numel() == 1:
 | |
|         return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
 | |
|     elif isinstance(obj, Tensor):
 | |
|         # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode
 | |
|         # >>> torch.tensor([torch.randn(2)])
 | |
|         # ValueError: only one element tensors can be converted to Python scalars
 | |
|         #
 | |
|         # But it is possible with a NumPy array
 | |
|         # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
 | |
|         # torch.Size([1, 2])
 | |
|         return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
 | |
|     elif isinstance(obj, Number):
 | |
|         return torch.scalar_tensor(obj, dtype=scalarType)
 | |
| 
 | |
|     # seq can be a list of tensors
 | |
|     seq = obj
 | |
|     return (
 | |
|         torch.empty(0)
 | |
|         if not seq
 | |
|         else torch.stack([_recursive_build(scalarType, item) for item in seq])
 | |
|     )
 | |
| 
 | |
| 
 | |
| # xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
 | |
| def _internal_new_from_data(
 | |
|     options,
 | |
|     scalar_type,
 | |
|     device_opt,
 | |
|     data,
 | |
|     copy_variables,
 | |
|     copy_numpy,
 | |
|     type_inference,
 | |
|     pin_memory=False,
 | |
| ):
 | |
|     if isinstance(data, torch.Tensor):
 | |
|         torch._check(
 | |
|             not pin_memory, lambda: "Can't pin tensor constructed from a variable"
 | |
|         )
 | |
|         var = data
 | |
|         if copy_variables:
 | |
|             var = var.detach()
 | |
|         inferred_scalar_type = var.dtype if type_inference else scalar_type
 | |
|         device = device_opt if device_opt is not None else var.device
 | |
|         return var.to(
 | |
|             device=device,
 | |
|             dtype=inferred_scalar_type,
 | |
|             non_blocking=False,
 | |
|             copy=copy_variables,
 | |
|         )
 | |
| 
 | |
|     # TODO
 | |
|     if hasattr(data, "__cuda_array_interface__"):
 | |
|         return NotImplemented
 | |
| 
 | |
|     # TODO: test for numpy input with PyArray_Check
 | |
| 
 | |
|     device = device_opt if device_opt is not None else options["device"]
 | |
|     inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type
 | |
| 
 | |
|     # NB: Don't need to avoid tracing, as we aren't going to do any manual
 | |
|     # pointer filling tricks
 | |
|     if _isStorage(data):
 | |
|         return NotImplemented
 | |
|     else:
 | |
|         if torch.device(device).type == "meta":
 | |
|             return NotImplemented
 | |
| 
 | |
|         # In the C implementation, we would directly start poking the memory
 | |
|         # of a freshly allocated CPU tensor.  Here, we're going to do an
 | |
|         # alternate, heinously slow implementation: turn each individual
 | |
|         # scalar into a tensor, and then repeatedly cat them together
 | |
|         tensor = _recursive_build(inferred_scalar_type, data)
 | |
| 
 | |
|         tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)
 | |
| 
 | |
|     # NB: lift_fresh is not needed, because we built the tensor from scalars
 | |
|     # guaranteeing a fresh tensor in this case
 | |
|     return tensor
 | |
| 
 | |
| 
 | |
| # xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
 | |
def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
    """Reference implementation of ``torch.tensor``.

    Args:
        data: a tensor, number, or nested sequence to build from.
        dtype: target dtype; when None the dtype is inferred from ``data``.
        device: target device; when None the helper falls back to the
            tensor's own device (tensor input) or "cpu" (anything else).
        pin_memory: forwarded to ``_internal_new_from_data``.
        requires_grad: when truthy, enable autograd on the result.

    Returns:
        The newly built, detached tensor.

    Warns:
        UserWarning: when ``data`` is already a tensor.
    """
    # TODO (or not): support names kwarg
    if isinstance(data, torch.Tensor):
        warnings.warn(
            "To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() "
            "or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor)",
            UserWarning,
            stacklevel=2,
        )

    infer_dtype = dtype is None
    # Ignored by the helper when the dtype is inferred; supplied for parity.
    fallback_dtype = torch.get_default_dtype() if infer_dtype else dtype
    # device="cpu" because that's what you get with torch.tensor(2) no
    # device by default
    base_options = {"device": "cpu"}  # TODO: use torch.get_default_tensor_type

    # NOTE(review): _internal_new_from_data can return NotImplemented (cuda
    # array interface, storage, or meta device); the detach_() call below
    # would then fail with AttributeError -- confirm whether callers rely on
    # that.
    result = _internal_new_from_data(
        base_options,
        fallback_dtype,
        device,
        data,
        copy_variables=True,
        copy_numpy=True,
        type_inference=infer_dtype,
        pin_memory=pin_memory,
    )
    result.detach_()
    if requires_grad:
        result.requires_grad_(requires_grad)
    return result
 | |
| 
 | |
| 
 | |
| # Views
 | |
| # We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function
 | |
| # given that it does not reshape the input (it just copies the result into it)
 | |
| 
 | |
| # squeeze_ = _make_inplace(squeeze)
 | |
| # t_ = _make_inplace(t)
 | |
| # transpose_ = _make_inplace(transpose)
 | |
| # unsqueeze_ = _make_inplace(unsqueeze)
 | |
| 
 | |
| 
 | |
| import torch._refs._conversions
 | |
| import torch._refs.fft
 | |
| import torch._refs.linalg
 | |
| import torch._refs.nn.functional
 | |
| import torch._refs.special
 |