mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Pull Request resolved: https://github.com/pytorch/pytorch/pull/164753 Approved by: https://github.com/aorenste, https://github.com/mlazos ghstack dependencies: #164664, #164665, #164667, #164668
		
			
				
	
	
		
			5320 lines
		
	
	
		
			176 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			5320 lines
		
	
	
		
			176 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # mypy: allow-untyped-decorators
 | |
| # mypy: allow-untyped-defs
 | |
| import functools
 | |
| import itertools
 | |
| import numbers
 | |
| import operator
 | |
| import sys
 | |
| from collections.abc import Callable, Iterable
 | |
| from contextlib import nullcontext
 | |
| from enum import Enum
 | |
| from functools import partial, reduce
 | |
| from itertools import chain, product
 | |
| from typing import Any, cast, Optional, Union
 | |
| 
 | |
| import torch
 | |
| import torch._meta_registrations
 | |
| import torch._prims as prims
 | |
| import torch._prims_common as utils
 | |
| import torch.nn.functional as F
 | |
| from torch import sym_float, sym_int, Tensor
 | |
| from torch._decomp import register_decomposition
 | |
| from torch._higher_order_ops.out_dtype import out_dtype
 | |
| from torch._prims_common import (
 | |
|     IntLike,
 | |
|     NumberType,
 | |
|     suggest_memory_format,
 | |
|     TensorLike,
 | |
|     TensorSequenceType,
 | |
| )
 | |
| from torch._prims_common.wrappers import (
 | |
|     _maybe_convert_to_dtype,
 | |
|     _maybe_resize_out,
 | |
|     _safe_copy_out,
 | |
|     out_wrapper,
 | |
| )
 | |
| from torch.utils import _pytree as pytree
 | |
| from torch.utils._pytree import tree_map
 | |
| 
 | |
| 
 | |
| DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
 | |
| 
 | |
| # None of these functions are publicly accessible; get at them
 | |
| # from torch._decomps
 | |
| __all__: list[str] = []
 | |
| 
 | |
| aten = torch._ops.ops.aten
 | |
| 
 | |
| 
 | |
| class Reduction(Enum):
 | |
|     NONE = 0
 | |
|     MEAN = 1
 | |
|     SUM = 2
 | |
| 
 | |
| 
 | |
| # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
 | |
| # We're currently reusing ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
 | |
| # Will need to validate the non-elementwise uses
 | |
| def type_casts(
 | |
|     f: Callable,
 | |
|     type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
 | |
|     compute_dtype_only: bool = False,
 | |
|     include_non_tensor_args: bool = False,
 | |
| ):
 | |
|     @functools.wraps(f)
 | |
|     def inner(*args, **kwargs):
 | |
|         allowed_types = (
 | |
|             (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
 | |
|         )  # type: ignore[arg-type]
 | |
|         flat_args = [
 | |
|             x
 | |
|             for x in pytree.arg_tree_leaves(*args, **kwargs)
 | |
|             if isinstance(x, allowed_types)
 | |
|         ]
 | |
|         computation_dtype, result_dtype = utils.elementwise_dtypes(
 | |
|             *flat_args, type_promotion_kind=type_promotion
 | |
|         )
 | |
| 
 | |
|         # TODO: pretty sure this is not quite right
 | |
|         def increase_prec(x):
 | |
|             if isinstance(x, Tensor):
 | |
|                 return x.to(computation_dtype)
 | |
|             else:
 | |
|                 return x
 | |
| 
 | |
|         def decrease_prec(x):
 | |
|             if isinstance(x, Tensor):
 | |
|                 return x.to(result_dtype)
 | |
|             else:
 | |
|                 return x
 | |
| 
 | |
|         r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
 | |
|         if compute_dtype_only:
 | |
|             return r
 | |
|         else:
 | |
|             return tree_map(decrease_prec, r)
 | |
| 
 | |
|     return inner
 | |
| 
 | |
| 
 | |
| compute_only_pw_cast_for_opmath = partial(
 | |
|     type_casts,
 | |
|     type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
 | |
|     compute_dtype_only=True,
 | |
| )
 | |
| pw_cast_for_opmath = partial(
 | |
|     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 | |
| )
 | |
| pw_cast_for_opmath_non_tensor_args = partial(
 | |
|     type_casts,
 | |
|     type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
 | |
|     include_non_tensor_args=True,
 | |
| )
 | |
| pw_cast_for_int_to_real = partial(
 | |
|     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
 | |
| )
 | |
| 
 | |
| 
 | |
| # This expands x until x.dim() == dim. Might be useful as an operator
 | |
| def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
 | |
|     for _ in range(dim - x.dim()):
 | |
|         x = x.unsqueeze(-1)
 | |
|     return x
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.tanh_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def tanh_backward(out_grad: Tensor, y: Tensor):
 | |
|     return out_grad * (1 - y * y).conj_physical()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.sigmoid_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def sigmoid_backward(out_grad: Tensor, y: Tensor):
 | |
|     return out_grad * (y * (1 - y)).conj_physical()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.softplus_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
 | |
|     z = (x * beta).exp()
 | |
|     return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.elu_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def elu_backward(
 | |
|     grad_output: Tensor,
 | |
|     alpha: float,
 | |
|     scale: float,
 | |
|     input_scale: float,
 | |
|     is_result: bool,
 | |
|     self_or_result: Tensor,
 | |
| ):
 | |
|     negcoef = alpha * scale
 | |
|     poscoef = scale
 | |
|     negiptcoef = input_scale
 | |
|     if is_result:
 | |
|         return torch.where(
 | |
|             self_or_result <= 0,
 | |
|             grad_output * negiptcoef * (self_or_result + negcoef),
 | |
|             grad_output * poscoef,
 | |
|         )
 | |
|     else:
 | |
|         return torch.where(
 | |
|             self_or_result <= 0,
 | |
|             grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
 | |
|             grad_output * poscoef,
 | |
|         )
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.fill.Scalar])
 | |
| def fill_scalar(self, value):
 | |
|     return torch.full_like(self, value)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.fill.Tensor])
 | |
| def fill_tensor(self, value: Tensor):
 | |
|     torch._check(
 | |
|         value.dim() == 0,
 | |
|         lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
 | |
|     )
 | |
|     return aten.copy(self, value)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.hardsigmoid)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def hardsigmoid(self: Tensor) -> Tensor:
 | |
|     return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.hardsigmoid_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
 | |
|     return torch.where(
 | |
|         (self > -3.0) & (self < 3.0),
 | |
|         grad_output * (1.0 / 6.0),
 | |
|         0.0,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.hardtanh_backward)
 | |
| @out_wrapper("grad_input")
 | |
| def hardtanh_backward(
 | |
|     grad_output: Tensor, self: Tensor, min_val: float, max_val: float
 | |
| ):
 | |
|     return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.hardswish)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def hardswish(self: Tensor) -> Tensor:
 | |
|     return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.hardswish_backward)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
 | |
|     return torch.where(
 | |
|         self <= -3,
 | |
|         0.0,
 | |
|         torch.where(self < 3, grad_output * ((self / 3) + 0.5), grad_output),
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.threshold_backward)
 | |
| @out_wrapper("grad_input")
 | |
| def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
 | |
|     return torch.where(self <= threshold, 0, grad_output)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.leaky_relu_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def leaky_relu_backward(
 | |
|     grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
 | |
| ):
 | |
|     return torch.where(self > 0, grad_output, grad_output * negative_slope)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.gelu_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
 | |
|     M_SQRT2 = 1.41421356237309504880
 | |
|     M_SQRT1_2 = 0.70710678118654752440
 | |
|     M_2_SQRTPI = 1.12837916709551257390
 | |
|     if approximate == "tanh":
 | |
|         kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
 | |
|         kKappa = 0.044715
 | |
|         x_sq = self * self
 | |
|         x_cube = x_sq * self
 | |
|         inner = kBeta * (self + kKappa * x_cube)
 | |
|         tanh_inner = torch.tanh(inner)
 | |
| 
 | |
|         left = 0.5 * self
 | |
|         right = 1 + tanh_inner
 | |
| 
 | |
|         left_derivative = 0.5 * right
 | |
| 
 | |
|         tanh_derivative = 1 - tanh_inner * tanh_inner
 | |
|         inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
 | |
|         right_derivative = left * tanh_derivative * inner_derivative
 | |
| 
 | |
|         return grad * (left_derivative + right_derivative)
 | |
|     else:
 | |
|         kAlpha = M_SQRT1_2
 | |
|         kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
 | |
|         cdf = 0.5 * (1 + torch.erf(self * kAlpha))
 | |
|         pdf = kBeta * torch.exp(self * self * -0.5)
 | |
|         return grad * (cdf + self * pdf)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.mish_backward)
 | |
| @pw_cast_for_opmath
 | |
| def mish_backward(grad_output: Tensor, input: Tensor):
 | |
|     input_tanh_softplus = torch.tanh(F.softplus(input))
 | |
|     input_sigmoid = torch.sigmoid(input)
 | |
|     out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
 | |
|     return grad_output * (input_tanh_softplus + out)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.silu)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def silu(self: Tensor) -> Tensor:
 | |
|     return self * torch.sigmoid(self)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.silu_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
 | |
|     sigmoid = 1 / (1 + torch.exp(-self))
 | |
|     return grad_output * sigmoid * (1 + self * (1 - sigmoid))
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._prelu_kernel)
 | |
| def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
 | |
|     return torch.where(self > 0, self, weight * self)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._prelu_kernel_backward)
 | |
| def _prelu_kernel_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     weight: Tensor,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     input_grad = torch.where(self > 0, grad_output, weight * grad_output)
 | |
|     weight_grad = torch.where(self > 0, 0.0, self * grad_output)
 | |
|     return (input_grad, weight_grad)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.rrelu_with_noise_backward)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def rrelu_with_noise_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     noise: Tensor,
 | |
|     lower: float,
 | |
|     upper: float,
 | |
|     training: bool,
 | |
|     self_is_result: bool,
 | |
| ) -> Tensor:
 | |
|     if training and upper - lower > 1e-6:
 | |
|         return grad_output.mul(noise)
 | |
|     else:
 | |
|         negative_slope = (lower + upper) / 2
 | |
|         return aten.leaky_relu_backward(
 | |
|             grad_output, self, negative_slope, self_is_result
 | |
|         )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.log_sigmoid_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
 | |
|     in_negative = self < 0
 | |
|     max_deriv = torch.where(in_negative, 1, 0)
 | |
|     sign = torch.where(in_negative, 1, -1)
 | |
|     z = torch.exp(-torch.abs(self))
 | |
|     return grad_output * (max_deriv - sign * (z / (1 + z)))
 | |
|     # CPU has a special formula that uses buffer, but disabled for convenience sake
 | |
|     # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
 | |
| 
 | |
| 
 | |
| def apply_loss_reduction(loss: Tensor, reduction: int):
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         return torch.mean(loss)
 | |
|     elif reduction == Reduction.SUM.value:
 | |
|         return torch.sum(loss)
 | |
|     else:
 | |
|         return loss
 | |
| 
 | |
| 
 | |
| def to_real_dtype(dtype: torch.dtype):
 | |
|     if dtype == torch.complex32:
 | |
|         return torch.float16
 | |
|     elif dtype == torch.complex64:
 | |
|         return torch.float32
 | |
|     elif dtype == torch.complex128:
 | |
|         return torch.float64
 | |
| 
 | |
| 
 | |
| # TODO: None of these loss castings are quite correct, see
 | |
| # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
 | |
| # perform the pointwise portion in opmath, but don't maintain it between the
 | |
| # pointwise portion and the reduction
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.mse_loss)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def mse_loss(
 | |
|     self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
 | |
| ) -> Tensor:
 | |
|     # pyrefly: ignore  # unsupported-operation
 | |
|     loss = (self - target) ** 2
 | |
|     return apply_loss_reduction(loss, reduction)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.mse_loss_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def mse_loss_backward(
 | |
|     grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
 | |
| ):
 | |
|     norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
 | |
|     return norm * (input - target) * grad_output
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._safe_softmax)
 | |
| def safe_softmax(self, dim, dtype=None):
 | |
|     out = torch.softmax(self, dim=dim, dtype=dtype)
 | |
|     masked = self.eq(float("-inf"))
 | |
|     masked_rows = torch.all(masked, dim=dim, keepdim=True)
 | |
|     zeros = torch.zeros_like(out)
 | |
|     return torch.where(masked_rows, zeros, out)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.smooth_l1_loss)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def smooth_l1_loss(
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
|     beta: float = 1.0,
 | |
| ):
 | |
|     loss = (self - target).abs()
 | |
|     # pyrefly: ignore  # unsupported-operation
 | |
|     loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
 | |
|     return apply_loss_reduction(loss, reduction)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.smooth_l1_loss_backward.default)
 | |
| @pw_cast_for_opmath
 | |
| def smooth_l1_loss_backward(
 | |
|     grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float
 | |
| ):
 | |
|     norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
 | |
|     x = self - target
 | |
|     abs_x = torch.abs(x)
 | |
|     norm_grad = norm * grad_output
 | |
|     return torch.where(
 | |
|         abs_x < beta,
 | |
|         norm_grad * x / beta,
 | |
|         norm_grad * torch.sign(x),
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.smooth_l1_loss_backward.grad_input)
 | |
| @pw_cast_for_opmath
 | |
| def smooth_l1_loss_backward_out(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int,
 | |
|     beta: float,
 | |
|     grad_input: Tensor,
 | |
| ):
 | |
|     result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta)
 | |
|     _maybe_resize_out(grad_input, result.shape)
 | |
|     return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.huber_loss_backward.default)
 | |
| @pw_cast_for_opmath
 | |
| def huber_loss_backward(
 | |
|     grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
 | |
| ):
 | |
|     norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
 | |
|     x = self - target
 | |
|     return torch.where(
 | |
|         x < -delta,
 | |
|         -norm * grad_output * delta,
 | |
|         torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
 | |
|     )
 | |
| 
 | |
| 
 | |
| # We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
 | |
| @register_decomposition(aten.huber_loss_backward.out)
 | |
| @pw_cast_for_opmath
 | |
| def huber_loss_backward_out(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int,
 | |
|     delta: float,
 | |
|     grad_input: Tensor,
 | |
| ):
 | |
|     result = huber_loss_backward(grad_output, self, target, reduction, delta)
 | |
|     _maybe_resize_out(grad_input, result.shape)
 | |
|     return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)
 | |
| 
 | |
| 
 | |
| def _nll_loss_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
|     total_weight: Tensor,
 | |
| ) -> Tensor:
 | |
|     channel_dim = 0 if self.dim() < 2 else 1
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         grad_output = grad_output / total_weight
 | |
| 
 | |
|     target = target.unsqueeze(channel_dim)
 | |
|     safe_target = torch.where(target != ignore_index, target, 0)
 | |
|     grad_input = torch.zeros_like(self)
 | |
|     grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
 | |
| 
 | |
|     if grad_input.dim() > grad_output.dim() > 0:
 | |
|         grad_output = grad_output.unsqueeze(channel_dim)
 | |
| 
 | |
|     if weight is not None:
 | |
|         new_shape = [1 for _ in range(self.dim())]
 | |
|         new_shape[channel_dim] = weight.shape[0]
 | |
|         weight = weight.reshape(new_shape)
 | |
|         grad_output = grad_output * weight
 | |
| 
 | |
|     grad_output = torch.where(target != ignore_index, grad_output, 0)
 | |
| 
 | |
|     return grad_input * grad_output
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.glu_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
 | |
|     assert self.dim() > 0, "glu does not support 0-dimensional tensors"
 | |
|     wrap_dim = utils.canonicalize_dim(self.dim(), dim)
 | |
|     nIn = self.size(wrap_dim)
 | |
|     assert nIn % 2 == 0, (
 | |
|         f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
 | |
|     )
 | |
|     inputSize = nIn // 2
 | |
|     firstHalf = self.narrow(wrap_dim, 0, inputSize)
 | |
|     secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
 | |
|     gradInputFirstHalf = torch.sigmoid(secondHalf)
 | |
|     gradInputSecondHalf = (
 | |
|         (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
 | |
|     )
 | |
|     gradInputFirstHalf = gradInputFirstHalf * grad_output
 | |
|     return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.nll_loss_backward)
 | |
| @out_wrapper("grad_input")
 | |
| def nll_loss_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
|     total_weight: Tensor,
 | |
| ) -> Tensor:
 | |
|     assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
 | |
|     assert target.dim() <= 1, (
 | |
|         "0D or 1D target tensor expected, multi-target not supported"
 | |
|     )
 | |
| 
 | |
|     no_batch_dim = self.dim() == 1 and target.dim() == 0
 | |
|     assert no_batch_dim or (self.shape[0] == target.shape[0]), (
 | |
|         f"size mismatch (got input: {self.shape}, target: {target.shape})"
 | |
|     )
 | |
|     assert total_weight.numel() == 1, (
 | |
|         "expected total_weight to be a single element tensor, got: ",
 | |
|         f"{total_weight.shape} ({total_weight.numel()} elements)",
 | |
|     )
 | |
| 
 | |
|     assert weight is None or weight.numel() == self.shape[-1], (
 | |
|         "weight tensor should be defined either for all or no classes"
 | |
|     )
 | |
| 
 | |
|     if reduction == Reduction.NONE.value and self.dim() == 2:
 | |
|         assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
 | |
|             f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
 | |
|             f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
 | |
|         )
 | |
|     else:
 | |
|         assert grad_output.dim() <= 1 and grad_output.numel() == 1, (
 | |
|             f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
 | |
|         )
 | |
| 
 | |
|     return _nll_loss_backward(
 | |
|         grad_output, self, target, weight, reduction, ignore_index, total_weight
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.nll_loss2d_backward)
 | |
| @out_wrapper("grad_input")
 | |
| def nll_loss2d_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
|     total_weight: Tensor,
 | |
| ) -> Tensor:
 | |
|     assert self.dim() == 4, (
 | |
|         f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
 | |
|     )
 | |
| 
 | |
|     assert target.dim() == 3, (
 | |
|         f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
 | |
|     )
 | |
| 
 | |
|     assert (
 | |
|         self.shape[0] == target.shape[0]
 | |
|         and self.shape[2] == target.shape[1]
 | |
|         and self.shape[3] == target.shape[2]
 | |
|     ), f"size mismatch (got input: {self.shape}, target: {target.shape}"
 | |
| 
 | |
|     assert total_weight.numel() == 1, (
 | |
|         "expected total_weight to be a single element tensor, "
 | |
|         f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
 | |
|     )
 | |
| 
 | |
|     return _nll_loss_backward(
 | |
|         grad_output, self, target, weight, reduction, ignore_index, total_weight
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.binary_cross_entropy)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def binary_cross_entropy(
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor] = None,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
| ) -> Tensor:
 | |
|     # We cannot currently model this without introducing data-dependent control flow
 | |
|     # TORCH_CHECK(
 | |
|     #     (input_val >= 0) && (input_val <= 1),
 | |
|     #     "all elements of input should be between 0 and 1"
 | |
|     # )
 | |
|     loss = (target - 1) * torch.maximum(
 | |
|         torch.log1p(-self), self.new_full((), -100)
 | |
|     ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
 | |
|     if weight is not None:
 | |
|         loss = loss * weight
 | |
|     return apply_loss_reduction(loss, reduction)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.binary_cross_entropy_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def binary_cross_entropy_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor] = None,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
| ) -> Tensor:
 | |
|     EPSILON = 1e-12
 | |
|     result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
 | |
|     if weight is not None:
 | |
|         result = result * weight
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         result = result / self.numel()
 | |
|     return result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.soft_margin_loss)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def soft_margin_loss(
 | |
|     input: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
| ) -> Tensor:
 | |
|     loss = torch.log1p(torch.exp(-input * target))
 | |
|     return apply_loss_reduction(loss, reduction)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.soft_margin_loss_backward)
 | |
| @out_wrapper("grad_input")
 | |
| @pw_cast_for_opmath
 | |
| def soft_margin_loss_backward(
 | |
|     grad_output: Tensor,
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
| ) -> Tensor:
 | |
|     grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         grad_input = grad_input / self.numel()
 | |
|     return grad_input
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.dist)
 | |
| @out_wrapper()
 | |
| def dist(input: Tensor, other: Tensor, p: float = 2):
 | |
|     return aten.norm(input - other, p=p)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._euclidean_dist)
 | |
| @out_wrapper()
 | |
| def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
 | |
|     x1_norm = x1.pow(2).sum(-1, True)
 | |
|     x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
 | |
|     x2_norm = x2.pow(2).sum(-1, True)
 | |
|     x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
 | |
|     x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
 | |
|     x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
 | |
|     result = x1_.matmul(x2_.mT)
 | |
|     return result.clamp_min(0).sqrt()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.slice_backward)
 | |
| @out_wrapper()
 | |
| def slice_backward(
 | |
|     grad_output: Tensor,
 | |
|     input_sizes: list[int],
 | |
|     dim: int,
 | |
|     start: int,
 | |
|     end: int,
 | |
|     step: int,
 | |
| ):
 | |
|     grad_input = grad_output.new_zeros(input_sizes)
 | |
|     return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.slice.Tensor)
 | |
| def slice_forward(
 | |
|     # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
 | |
|     self: Tensor,
 | |
|     dim: int = 0,
 | |
|     start: Optional[int] = None,
 | |
|     end: Optional[int] = None,
 | |
|     step: int = 1,
 | |
| ):
 | |
|     from torch.fx.experimental.symbolic_shapes import statically_known_true
 | |
| 
 | |
|     ndim = self.dim()
 | |
|     if ndim == 0:
 | |
|         raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
 | |
|     dim = utils.canonicalize_dim(self.dim(), dim)
 | |
|     sizes = list(self.size())
 | |
|     strides = list(self.stride())
 | |
| 
 | |
|     if step <= 0:
 | |
|         raise RuntimeError("slice step must be positive")
 | |
| 
 | |
|     start_val = start if start is not None else 0
 | |
|     end_val = end if end is not None else sys.maxsize  # 2^63 - 1
 | |
| 
 | |
|     if start_val < 0:
 | |
|         start_val += sizes[dim]
 | |
| 
 | |
|     if end_val < 0:
 | |
|         end_val += sizes[dim]
 | |
| 
 | |
|     if start_val < 0:
 | |
|         start_val = 0
 | |
|     elif start_val > sizes[dim]:
 | |
|         start_val = sizes[dim]
 | |
| 
 | |
|     if statically_known_true(end_val == sys.maxsize):
 | |
|         end_val = sizes[dim]
 | |
|     elif end_val < start_val:
 | |
|         end_val = start_val
 | |
|     elif end_val > sizes[dim]:
 | |
|         end_val = sizes[dim]
 | |
| 
 | |
|     storage_offset = self.storage_offset() + start_val * strides[dim]
 | |
|     len = end_val - start_val
 | |
|     sizes[dim] = (len + step - 1) // step
 | |
|     strides[dim] *= step
 | |
| 
 | |
|     if self.is_quantized:
 | |
|         raise NotImplementedError(
 | |
|             "Slice decomposition for quantized tensors aren't implemented"
 | |
|         )
 | |
|     else:
 | |
|         return self.as_strided(sizes, strides, storage_offset)
 | |
| 
 | |
| 
 | |
| def _normalize_start_end(
 | |
|     x: Tensor, dim: int, start: Optional[int], end: Optional[int]
 | |
| ) -> tuple[int, int]:
 | |
|     """
 | |
|     Normalize start and end such that both are in the range
 | |
|     [0, x.get_size()[dim]] and start <= end.
 | |
|     """
 | |
|     dim_size = x.shape[dim]
 | |
| 
 | |
|     def clamp_wrap(val, lower, upper, default) -> int:
 | |
|         if val is None:
 | |
|             return default
 | |
|         if val < 0:
 | |
|             val = val + dim_size
 | |
|         return min(max(val, lower), upper)
 | |
| 
 | |
|     start = clamp_wrap(start, 0, dim_size, 0)
 | |
|     end = clamp_wrap(end, start, dim_size, dim_size)
 | |
|     return start, end
 | |
| 
 | |
| 
 | |
| # This is not in torch._refs because aten.index used by
 | |
| # aten._unsafe_masked_index does not have a decomposition.
 | |
| @register_decomposition(aten.slice_scatter)
 | |
| @out_wrapper()
 | |
| def slice_scatter(
 | |
|     input: Tensor,
 | |
|     src: Tensor,
 | |
|     dim: int = 0,
 | |
|     start: Optional[int] = None,
 | |
|     end: Optional[int] = None,
 | |
|     step: int = 1,
 | |
| ):
 | |
|     dim = utils.canonicalize_dim(input.ndim, dim)
 | |
|     dim_size = input.shape[dim]
 | |
|     start, end = _normalize_start_end(input, dim, start, end)
 | |
| 
 | |
|     src_size = list(input.shape)
 | |
|     src_size[dim] = (end - start + (step - 1)) // step
 | |
|     src = src.expand(src_size)
 | |
| 
 | |
|     if start == 0 and end == dim_size and step == 1:
 | |
|         return src.clone()
 | |
| 
 | |
|     indices: list[Optional[Tensor]] = [None] * input.dim()
 | |
|     idx = torch.arange(dim_size, device=input.device)
 | |
|     indices[dim] = (idx - start) // step
 | |
| 
 | |
|     mask = torch.ones(dim_size, device=input.device, dtype=torch.bool)
 | |
|     if start != 0:
 | |
|         mask = torch.logical_and(mask, idx >= start)
 | |
| 
 | |
|     if end != dim_size:
 | |
|         mask = torch.logical_and(mask, idx < end)
 | |
| 
 | |
|     if step != 1:
 | |
|         mask = torch.logical_and(mask, (idx - start) % step == 0)
 | |
| 
 | |
|     mask_shape = [1] * input.dim()
 | |
|     mask_shape[dim] = -1
 | |
|     mask = mask.view(mask_shape)
 | |
|     return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.select_backward)
 | |
| @out_wrapper()
 | |
| def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int):
 | |
|     grad_input = grad_output.new_zeros(input_sizes)
 | |
|     return torch.select_scatter(grad_input, grad_output, dim, index)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.diagonal_backward)
 | |
| @out_wrapper()
 | |
| def diagonal_backward(
 | |
|     grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
 | |
| ):
 | |
|     grad_input = grad_output.new_zeros(input_sizes)
 | |
|     return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
 | |
| 
 | |
| 
 | |
| def _cast_grad_to_input_dtype(
 | |
|     grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
 | |
| ):
 | |
|     if grad_output.dtype != input_dtype:
 | |
|         grad_input = grad_input.to(input_dtype)
 | |
|     return grad_input
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._softmax_backward_data)
 | |
| @out_wrapper("grad_input")
 | |
| @compute_only_pw_cast_for_opmath
 | |
| def _softmax_backward_data(
 | |
|     grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
 | |
| ):
 | |
|     new_grad_output = grad_output * output
 | |
|     grad_input = new_grad_output - output * torch.sum(
 | |
|         new_grad_output, dim=dim, keepdim=True
 | |
|     )
 | |
| 
 | |
|     # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
 | |
|     # if grad_output.device == torch.device("cpu"):
 | |
|     #     return grad_input.contiguous()
 | |
| 
 | |
|     return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._log_softmax_backward_data)
 | |
| @out_wrapper()
 | |
| @compute_only_pw_cast_for_opmath
 | |
| def _log_softmax_backward_data(
 | |
|     grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
 | |
| ):
 | |
|     grad_input = grad_output - torch.exp(output) * torch.sum(
 | |
|         grad_output, dim=dim, keepdim=True
 | |
|     )
 | |
|     return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)
 | |
| 
 | |
| 
 | |
| def _im2col_col2im_indices_along_dim(
 | |
|     input_d, kernel_d, dilation_d, padding_d, stride_d, device
 | |
| ):
 | |
|     """Utility function to implement im2col and col2im"""
 | |
|     blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
 | |
| 
 | |
|     arange_kw = partial(torch.arange, dtype=torch.int64, device=device)
 | |
| 
 | |
|     # Stride kernel over input and find starting indices along dim d
 | |
|     blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)
 | |
| 
 | |
|     # Apply dilation on kernel and find its indices along dim d
 | |
|     kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)
 | |
| 
 | |
|     # Broadcast and add kernel starting positions (indices) with
 | |
|     # kernel_grid along dim d, to get block indices along dim d
 | |
|     return blocks_d_indices + kernel_grid
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.im2col)
 | |
| @out_wrapper()
 | |
| def im2col(
 | |
|     input: Tensor,
 | |
|     kernel_size: list[int],
 | |
|     dilation: list[int],
 | |
|     padding: list[int],
 | |
|     stride: list[int],
 | |
| ) -> Tensor:
 | |
|     torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
 | |
|     torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
 | |
|     torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
 | |
|     torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
 | |
| 
 | |
|     def check_positive(param, param_name, strict=True):
 | |
|         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
 | |
|         torch._check(
 | |
|             cond, lambda: f"{param_name} should be greater than zero, but got {param}"
 | |
|         )
 | |
| 
 | |
|     check_positive(kernel_size, "kernel_size")
 | |
|     check_positive(dilation, "dilation")
 | |
|     check_positive(dilation, "padding", strict=False)
 | |
|     check_positive(stride, "stride")
 | |
| 
 | |
|     shape = input.shape
 | |
|     ndim = len(shape)
 | |
|     torch._check(
 | |
|         ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
 | |
|         lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
 | |
|         f"and non-zero dimensions, but got: {tuple(shape)}",
 | |
|     )
 | |
|     output_size = tuple(
 | |
|         1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
 | |
|         for out, pad, dil, ker, st in zip(
 | |
|             shape[-2:], padding, dilation, kernel_size, stride
 | |
|         )
 | |
|     )
 | |
|     torch._check(
 | |
|         all(c > 0 for c in output_size),
 | |
|         lambda: f"Given an input with spatial size {tuple(shape[-2:])}, "
 | |
|         f"kernel_size={kernel_size}, dilation={dilation}, "
 | |
|         f"padding={padding}, stride={stride}, "
 | |
|         "the calculated shape of the array of sliding blocks "
 | |
|         f"is {output_size}, but its components must be at least one.",
 | |
|     )
 | |
|     batched_input = ndim == 4
 | |
|     if not batched_input:
 | |
|         input = input.unsqueeze(0)
 | |
| 
 | |
|     batch_dim, channel_dim, input_h, input_w = input.shape
 | |
| 
 | |
|     stride_h, stride_w = stride
 | |
|     padding_h, padding_w = padding
 | |
|     dilation_h, dilation_w = dilation
 | |
|     kernel_h, kernel_w = kernel_size
 | |
| 
 | |
|     blocks_row_indices = _im2col_col2im_indices_along_dim(
 | |
|         input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
 | |
|     )
 | |
|     blocks_col_indices = _im2col_col2im_indices_along_dim(
 | |
|         input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
 | |
|     )
 | |
| 
 | |
|     # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
 | |
|     # ugh
 | |
|     padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
 | |
| 
 | |
|     blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
 | |
|     output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
 | |
|     output = output.permute(0, 1, 2, 4, 3, 5)
 | |
|     num_blocks_row = blocks_row_indices.size(1)
 | |
|     num_blocks_col = blocks_col_indices.size(1)
 | |
|     output = output.reshape(
 | |
|         batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
 | |
|     )
 | |
| 
 | |
|     if not batched_input:
 | |
|         output = output.squeeze(0)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.col2im)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def col2im(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     kernel_size: list[int],
 | |
|     dilation: list[int],
 | |
|     padding: list[int],
 | |
|     stride: list[int],
 | |
| ) -> Tensor:
 | |
|     torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
 | |
|     torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
 | |
|     torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
 | |
|     torch._check(len(padding) == 2, lambda: "only 2D padding supported")
 | |
|     torch._check(len(stride) == 2, lambda: "only 2D stride supported")
 | |
| 
 | |
|     def check_positive(param, param_name, strict=True):
 | |
|         cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
 | |
|         torch._check(
 | |
|             cond, lambda: f"{param_name} should be greater than zero, but got {param}"
 | |
|         )
 | |
| 
 | |
|     check_positive(kernel_size, "kernel_size")
 | |
|     check_positive(dilation, "dilation")
 | |
|     check_positive(padding, "padding", strict=False)
 | |
|     check_positive(stride, "stride")
 | |
|     check_positive(output_size, "output_size")
 | |
| 
 | |
|     shape = input.shape
 | |
|     ndim = len(shape)
 | |
|     torch._check(
 | |
|         ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
 | |
|         lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
 | |
|         f"and non-zero dimensions, but got: {tuple(shape)}",
 | |
|     )
 | |
|     prod_kernel_size = kernel_size[0] * kernel_size[1]
 | |
|     torch._check(
 | |
|         shape[-2] % prod_kernel_size == 0,
 | |
|         lambda: "Expected size of input's first non-batch dimension to be divisible by the "
 | |
|         f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
 | |
|         f"kernel_size={kernel_size}",
 | |
|     )
 | |
|     col = [
 | |
|         1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
 | |
|         for out, pad, dil, ker, st in zip(
 | |
|             output_size, padding, dilation, kernel_size, stride
 | |
|         )
 | |
|     ]
 | |
|     L = col[0] * col[1]
 | |
|     torch._check(
 | |
|         shape[-1] == L,
 | |
|         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
 | |
|         f"dilation={dilation}, padding={padding}, stride={stride}, "
 | |
|         f"expected input.size(-1) to be {L} but got {shape[-1]}.",
 | |
|     )
 | |
|     torch._check(
 | |
|         L > 0,
 | |
|         lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
 | |
|         f"dilation={dilation}, padding={padding}, stride={stride}, "
 | |
|         f"expected input.size(-1) to be {L} but got {shape[-1]}.",
 | |
|     )
 | |
|     batched_input = ndim == 3
 | |
|     if not batched_input:
 | |
|         input = input.unsqueeze(0)
 | |
| 
 | |
|     shape = input.shape
 | |
| 
 | |
|     out_h, out_w = output_size
 | |
|     stride_h, stride_w = stride
 | |
|     padding_h, padding_w = padding
 | |
|     dilation_h, dilation_w = dilation
 | |
|     kernel_h, kernel_w = kernel_size
 | |
| 
 | |
|     # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
 | |
|     input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
 | |
|     input = input.permute(0, 1, 2, 4, 3, 5)
 | |
| 
 | |
|     indices_row = _im2col_col2im_indices_along_dim(
 | |
|         out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
 | |
|     )
 | |
|     indices_row = _unsqueeze_to_dim(indices_row, 4)
 | |
|     indices_col = _im2col_col2im_indices_along_dim(
 | |
|         out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
 | |
|     )
 | |
| 
 | |
|     output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
 | |
|     output = input.new_zeros(
 | |
|         [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
 | |
|     )
 | |
|     idx = (None, None, indices_row, indices_col)
 | |
|     output = aten._unsafe_index_put(output, idx, input, accumulate=True)
 | |
|     output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))
 | |
| 
 | |
|     if not batched_input:
 | |
|         output = output.squeeze(0)
 | |
|     return output
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.native_dropout_backward)
 | |
| @out_wrapper()
 | |
| def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
 | |
|     # According to the CUDA kernel implementation we should have this test;
 | |
|     # but it seems to fail tests!
 | |
|     # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
 | |
| 
 | |
|     # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
 | |
|     # This different from TensorIterator's behavior
 | |
|     r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
 | |
|         memory_format=utils.suggest_memory_format(grad_output)
 | |
|     )
 | |
|     return r
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.unfold_backward)
 | |
| @out_wrapper()
 | |
| def unfold_backward(
 | |
|     grad: Tensor, input_size: list[int], dimension: int, size: int, step: int
 | |
| ) -> Tensor:
 | |
|     if len(input_size) == 0:
 | |
|         return torch.squeeze_copy(grad, 0)
 | |
|     dim = utils.canonicalize_dim(len(input_size), dimension)
 | |
|     idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
 | |
|     idx = idx.unfold(0, size, step).flatten()
 | |
|     grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
 | |
|     # nb. At the moment this generates two kernels in triton
 | |
|     # It could potentially be fused into one call to scatter_reduce,
 | |
|     # in the case step <= size provided scatter_reduce generates 1 kernel
 | |
|     grad_input = grad.new_zeros(input_size)
 | |
|     index = (None,) * dim + (idx,)
 | |
|     return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.logit_backward.default)
 | |
| @pw_cast_for_opmath
 | |
| def logit_backward(
 | |
|     grad_output: Tensor, self: Tensor, eps: Optional[float] = None
 | |
| ) -> Tensor:
 | |
|     if eps is not None:
 | |
|         lo = eps
 | |
|         hi = 1.0 - lo
 | |
|         return torch.where(
 | |
|             torch.logical_and(self >= lo, self <= hi),
 | |
|             grad_output / (self * (1.0 - self)),
 | |
|             0.0,
 | |
|         )
 | |
|     else:
 | |
|         return torch.where(
 | |
|             torch.logical_and(self >= 0.0, self <= 1.0),
 | |
|             grad_output / (self * (1.0 - self)),
 | |
|             self.new_full((), float("nan")),
 | |
|         )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.dropout)
 | |
| @aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.dropout.default.py_impl(DispatchKey.Autograd)
 | |
| def dropout(input: Tensor, p: float, train: Optional[bool]):
 | |
|     if train and p != 0:
 | |
|         return aten.native_dropout(input, p, train)[0]
 | |
|     else:
 | |
|         return input.clone()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.native_dropout)
 | |
| @out_wrapper("out0", "out1")
 | |
| def native_dropout(input: Tensor, p: float, train: Optional[bool]):
 | |
|     if train and p != 0:
 | |
|         if p == 1:
 | |
|             return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
 | |
|         if not input.dtype.is_floating_point:
 | |
|             raise RuntimeError(
 | |
|                 "result type Float can't be cast to the desired output type Long"
 | |
|             )
 | |
|         bool_mask = torch.rand_like(input) > p
 | |
|         res = bool_mask * input * float(1.0 / (1.0 - p))
 | |
|         return (res, bool_mask)
 | |
|     else:
 | |
|         return (input, torch.ones_like(input, dtype=torch.bool))
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._softmax)
 | |
| @out_wrapper()
 | |
| def _softmax(x: Tensor, dim: int, half_to_float: bool):
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     # eager softmax returns a contiguous tensor. Ensure that decomp also returns
 | |
|     # a contiguous tensor.
 | |
|     x = x.contiguous()
 | |
|     if half_to_float:
 | |
|         assert x.dtype == torch.half
 | |
|     computation_dtype, result_dtype = utils.elementwise_dtypes(
 | |
|         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 | |
|     )
 | |
|     x = x.to(computation_dtype)
 | |
|     if guard_or_false(x.numel() == 0):
 | |
|         unnormalized = torch.exp(x)
 | |
|     else:
 | |
|         x_max = torch.amax(x, dim, keepdim=True)
 | |
|         unnormalized = torch.exp(x - x_max)
 | |
|     result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
 | |
|     if not half_to_float:
 | |
|         result = result.to(result_dtype)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._log_softmax)
 | |
| @out_wrapper(exact_dtype=True)
 | |
| def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     # eager log_softmax returns a contiguous tensor. Ensure that decomp also
 | |
|     # returns a contiguous tensor.
 | |
|     x = x.contiguous()
 | |
|     if half_to_float:
 | |
|         assert x.dtype == torch.half
 | |
|     computation_dtype, result_dtype = utils.elementwise_dtypes(
 | |
|         x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 | |
|     )
 | |
|     x = x.to(computation_dtype)
 | |
|     if guard_or_false(x.numel() == 0):
 | |
|         shifted = x
 | |
|     else:
 | |
|         x_max = torch.amax(x, dim, keepdim=True)
 | |
|         shifted = x - x_max
 | |
|     shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
 | |
|     result = shifted - shifted_logsumexp
 | |
|     if not half_to_float:
 | |
|         result = result.to(result_dtype)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.embedding)
 | |
| @out_wrapper()
 | |
| def embedding(
 | |
|     weight: Tensor,
 | |
|     indices: Tensor,
 | |
|     padding_idx: int = -1,
 | |
|     scale_grad_by_freq: bool = False,
 | |
|     sparse: bool = False,
 | |
| ) -> Tensor:
 | |
|     assert weight.dim() == 2, "'weight' must be 2-D"
 | |
|     # Nb. scale_grad_by_freq is not used in the forward
 | |
|     if indices.ndim <= 1:
 | |
|         # We need this one as weight[indices] calls item() in these cases
 | |
|         out = weight.index_select(0, indices)
 | |
|         if indices.ndim == 0:
 | |
|             out = out.squeeze(0)
 | |
|         return out
 | |
|     else:
 | |
|         return weight[indices]
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.embedding_dense_backward)
 | |
| @out_wrapper()
 | |
| def embedding_dense_backward(
 | |
|     grad_output: Tensor,
 | |
|     indices: Tensor,
 | |
|     num_weights: int,
 | |
|     padding_idx: int,
 | |
|     scale_grad_by_freq: bool,
 | |
| ):
 | |
|     computation_dtype, result_dtype = utils.elementwise_dtypes(
 | |
|         grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 | |
|     )
 | |
|     grad_output = grad_output.to(computation_dtype)
 | |
|     indices = _maybe_convert_to_dtype(indices, torch.long)  # type: ignore[assignment]
 | |
|     if scale_grad_by_freq:
 | |
|         counts = indices.new_zeros((num_weights,))
 | |
|         ones = torch.ones_like(indices)
 | |
|         counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True)
 | |
|         grad_weights_scale = counts[indices]
 | |
|         grad_output = grad_output / grad_weights_scale.unsqueeze(-1)
 | |
| 
 | |
|     mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
 | |
|     grad = grad_output.masked_fill(mask, 0)
 | |
|     grad_weight = grad_output.new_zeros(
 | |
|         (num_weights,) + grad_output.shape[indices.ndim :]
 | |
|     )
 | |
|     return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to(
 | |
|         result_dtype
 | |
|     )
 | |
| 
 | |
| 
 | |
| def prod(x: list[int]):
 | |
|     r = 1
 | |
|     for i in x:
 | |
|         r *= i
 | |
|     return r
 | |
| 
 | |
| 
 | |
| def _pad_chunk(
 | |
|     tensors: list[Tensor],
 | |
|     dim: int,
 | |
|     num_chunks: int,
 | |
| ) -> list[Tensor]:
 | |
|     padded_tensors = []
 | |
|     for tensor in tensors:
 | |
|         tensor_size = tensor.size()
 | |
|         pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
 | |
|         if pad_along_dim != tensor_size[dim]:
 | |
|             # Use aten.constant_pad_nd instead of copy_ for functionalization
 | |
|             pad = [0] * 2 * (tensor.ndim - dim - 1) + [
 | |
|                 0,
 | |
|                 pad_along_dim - tensor_size[dim],
 | |
|             ]
 | |
|             tensor = aten.constant_pad_nd(tensor, pad, 0)
 | |
|         view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
 | |
|         padded_tensors.append(tensor.reshape(view_size))
 | |
|     return padded_tensors
 | |
| 
 | |
| 
 | |
| def have_same_ndims(tensors: list[Tensor]):
 | |
|     ndim = tensors[0].ndim
 | |
|     for tensor in tensors:
 | |
|         if tensor.ndim != ndim:
 | |
|             return False
 | |
|     return True
 | |
| 
 | |
| 
 | |
| def leading_dimension_matches(tensors: list[Tensor], dim: int):
 | |
|     leading_dim_sizes = tensors[0].size()[:dim]
 | |
|     for tensor in tensors:
 | |
|         torch._check(
 | |
|             tensor.size()[:dim] == leading_dim_sizes,
 | |
|             lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
 | |
|         )
 | |
| 
 | |
| 
 | |
| def _preprocess_chunk_cat_inputs(
 | |
|     tensors: list[Tensor],
 | |
|     dim: int,
 | |
|     num_chunks: int,
 | |
| ):
 | |
|     torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
 | |
|     torch._check(
 | |
|         len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
 | |
|     )
 | |
|     expected_dtype = tensors[0].dtype
 | |
|     expected_device = tensors[0].device
 | |
|     for tensor in tensors:
 | |
|         torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
 | |
|         torch._check(
 | |
|             tensor.dtype == expected_dtype,
 | |
|             lambda: "_chunk_cat expects all input tensors with the same dtype",
 | |
|         )
 | |
|         torch._check(
 | |
|             tensor.device == expected_device,
 | |
|             lambda: "_chunk_cat expects all inputs tensors on the same device",
 | |
|         )
 | |
|     if have_same_ndims(tensors):
 | |
|         dim = utils.canonicalize_dim(tensors[0].dim(), dim)
 | |
|     else:
 | |
|         torch._check(
 | |
|             dim >= 0,
 | |
|             lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
 | |
|         )
 | |
|         for tensor in tensors:
 | |
|             torch._check(
 | |
|                 dim < tensor.ndim,
 | |
|                 lambda: "_chunk_cat expects dim < ndim for all input tensors",
 | |
|             )
 | |
|     leading_dimension_matches(tensors, dim)
 | |
|     return dim
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
 | |
| def _chunk_cat(
 | |
|     tensors: list[Tensor],
 | |
|     dim: int,
 | |
|     num_chunks: int,
 | |
|     out: Optional[Tensor] = None,
 | |
| ) -> Tensor:
 | |
|     dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
 | |
|     padded_tensors = _pad_chunk(tensors, dim, num_chunks)
 | |
|     if out is None:
 | |
|         return torch.cat(padded_tensors, dim + 1)
 | |
|     else:
 | |
|         torch.cat(padded_tensors, dim + 1, out=out)
 | |
|         return out
 | |
| 
 | |
| 
 | |
| # out_wrapper currently does not allow optional outputs
 | |
| @register_decomposition(
 | |
|     [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out]
 | |
| )
 | |
| def split_with_sizes_copy(
 | |
|     self: Tensor,
 | |
|     split_sizes: list[int],
 | |
|     dim: int = 0,
 | |
|     out: Optional[list[Tensor]] = None,
 | |
| ) -> Optional[list[Tensor]]:
 | |
|     splits = aten.split_with_sizes(self, split_sizes, dim=dim)
 | |
|     if out is None:
 | |
|         return [s.clone(memory_format=torch.contiguous_format) for s in splits]
 | |
|     else:
 | |
|         for output, split in zip(out, splits):
 | |
|             _maybe_resize_out(output, split.shape)
 | |
|             _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True)
 | |
|         return None
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.unsafe_split.Tensor)
 | |
| def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
 | |
|     return aten.split.Tensor(input, split_size, dim)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.unsafe_split_with_sizes.default)
 | |
| def unsafe_split_with_sizes(
 | |
|     input: Tensor, split_sizes: list[int], dim: int = 0
 | |
| ) -> tuple[Tensor, ...]:
 | |
|     return aten.split_with_sizes.default(input, split_sizes, dim)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.split.Tensor)
 | |
| def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
 | |
|     input_sizes = self.shape
 | |
|     dim_size = input_sizes[dim]
 | |
|     if split_size == 0:
 | |
|         assert dim_size == 0
 | |
|         return (self.detach(),)
 | |
|     chunks = (dim_size + split_size - 1) // split_size
 | |
| 
 | |
|     # Avoid importing sympy at a module level
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_int
 | |
| 
 | |
|     chunks = guard_int(chunks)
 | |
|     split_sizes = [split_size for i in range(chunks)]
 | |
|     split_sizes[-1] = split_size - (split_size * chunks - dim_size)
 | |
|     return torch.split(self, split_sizes, dim)
 | |
| 
 | |
| 
 | |
| @aten.tensor_split.tensor_indices_or_sections.py_impl(
 | |
|     DispatchKey.CompositeImplicitAutograd
 | |
| )
 | |
| def tensor_split_tensor_indices_or_sections_py_impl(
 | |
|     self: Tensor,
 | |
|     tensor_indices_or_sections: Tensor,
 | |
|     dim: int = 0,
 | |
| ) -> tuple[Tensor, ...]:
 | |
|     assert tensor_indices_or_sections.device.type == "cpu"
 | |
|     assert tensor_indices_or_sections.dtype == torch.int64
 | |
|     split_dim = tensor_indices_or_sections.dim()
 | |
|     torch._check(
 | |
|         split_dim == 1 or split_dim == 0,
 | |
|         lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional "
 | |
|         f"or one-dimensional tensor, but got a tensor with {split_dim} dims",
 | |
|     )
 | |
|     if split_dim == 0:
 | |
|         sections = tensor_indices_or_sections.item()
 | |
|         assert isinstance(sections, IntLike)
 | |
|         return self.tensor_split(sections, dim)
 | |
|     else:
 | |
|         ctx = nullcontext
 | |
|         if (fake_mode := torch._guards.detect_fake_mode()) and (
 | |
|             shape_env := fake_mode.shape_env
 | |
|         ):
 | |
|             ctx = shape_env.ignore_fresh_unbacked_symbols  # type: ignore[assignment]
 | |
|         # In fake tensor prop, we end up calling slice() with these unbacked indices.
 | |
|         # Because slice has flexible semantics, the unbacked handling generates new output sizes
 | |
|         # for each slice, effectively clobbering over these index symbols.
 | |
|         # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
 | |
|         with ctx():
 | |
|             indices = [i.item() for i in tensor_indices_or_sections]
 | |
|         # WARNING: Tempted to torch._check(x>0) on the indices here?  You
 | |
|         # can't: tensor_split works with negative values in indices:
 | |
|         #
 | |
|         # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
 | |
|         # (tensor([ 0.3540,  2.1074, -0.8507,  1.1639,  0.3055]), tensor([]),
 | |
|         # tensor([-0.4285,  1.0692, -0.1776,  0.9362,  1.6143]))
 | |
|         #
 | |
|         # Sorry, I don't make the rules.  Explicitly do the item call in user
 | |
|         # code if you KNOW that they are non-negative.
 | |
|         return self.tensor_split(indices, dim)
 | |
| 
 | |
| 
 | |
| # TODO: this doesn't appear to have enough precision in bfloat16
 | |
| @register_decomposition(aten.addmm)
 | |
| @out_wrapper(exact_dtype=True)
 | |
| @pw_cast_for_opmath
 | |
| def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
 | |
|     if not self.is_floating_point() and not self.is_complex():
 | |
|         beta = int(beta)
 | |
|         alpha = int(alpha)
 | |
|     out = alpha * torch.mm(mat1, mat2)
 | |
|     if beta == 0:
 | |
|         return out
 | |
| 
 | |
|     # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
 | |
|     # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
 | |
|     # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
 | |
|     # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
 | |
|     # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
 | |
|     # This implementation is not ideal, and we should revisit this when we have a better solution.
 | |
|     return out + beta * self
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._addmm_activation)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def _addmm_activation(
 | |
|     self: Tensor,
 | |
|     mat1: Tensor,
 | |
|     mat2: Tensor,
 | |
|     beta: int = 1,
 | |
|     alpha: int = 1,
 | |
|     use_gelu: bool = False,
 | |
| ):
 | |
|     out = addmm(self, mat1, mat2, beta, alpha)
 | |
|     if use_gelu:
 | |
|         if self.is_cuda:
 | |
|             return aten.gelu(out, approximate="tanh")
 | |
|         else:
 | |
|             return aten.gelu(out)
 | |
|     return aten.relu(out)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.addmv)
 | |
| @out_wrapper(exact_dtype=True)
 | |
| @pw_cast_for_opmath
 | |
| def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
 | |
|     if not self.is_floating_point() and not self.is_complex():
 | |
|         beta = int(beta)
 | |
|         alpha = int(alpha)
 | |
|     out = alpha * torch.mv(mat1, vec)
 | |
|     if beta == 0:
 | |
|         return out
 | |
|     if out.numel() == 0:  # handle empty matrix
 | |
|         return beta * self
 | |
|     return out + beta * self
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.native_group_norm_backward.default)
 | |
| @pw_cast_for_opmath
 | |
| def native_group_norm_backward(
 | |
|     grad_output: Tensor,
 | |
|     input: Tensor,
 | |
|     mean: Tensor,
 | |
|     rstd: Tensor,
 | |
|     gamma: Optional[Tensor],
 | |
|     N: int,
 | |
|     C: int,
 | |
|     HxW: int,
 | |
|     group: int,
 | |
|     output_mask: list[bool],
 | |
| ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
 | |
|     utils.check_same_device(
 | |
|         grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
 | |
|     )
 | |
|     utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
 | |
|     utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
 | |
|     torch._check(
 | |
|         input.numel() == N * C * HxW,
 | |
|         lambda: f"Expect input to have {N * C * HxW} elements",
 | |
|     )
 | |
|     torch._check(
 | |
|         mean.shape == (N, group),
 | |
|         lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
 | |
|     )
 | |
|     torch._check(
 | |
|         gamma is None or gamma.numel() == C,
 | |
|         lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
 | |
|     )
 | |
| 
 | |
|     cpg = C // group
 | |
|     torch._check(
 | |
|         C == cpg * group,
 | |
|         lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
 | |
|     )
 | |
| 
 | |
|     # Compute Internal gradients
 | |
|     ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
 | |
|     db = grad_output.view(N, C, HxW).sum(dim=[2])
 | |
| 
 | |
|     d_input: Optional[Tensor] = None
 | |
|     d_gamma: Optional[Tensor] = None
 | |
|     d_bias: Optional[Tensor] = None
 | |
|     if output_mask[0]:
 | |
|         s = 1.0 / (HxW * cpg)
 | |
|         if gamma is not None:
 | |
|             ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
 | |
|             db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
 | |
|             c1 = torch.mul(
 | |
|                 rstd.unsqueeze(-1),
 | |
|                 gamma.reshape(1, group, cpg),
 | |
|             )
 | |
|         else:
 | |
|             ds_val = ds.reshape(N, group, cpg).sum(2)
 | |
|             db_val = db.reshape(N, group, cpg).sum(2)
 | |
|             c1 = torch.mul(
 | |
|                 rstd.unsqueeze(-1),
 | |
|                 torch.ones((1, group, cpg), device=rstd.device),
 | |
|             )
 | |
|         c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
 | |
|         c3 = -c2 * mean - db_val * rstd * s
 | |
| 
 | |
|         c1 = c1.unsqueeze(-1)
 | |
|         c2 = _unsqueeze_to_dim(c2, 4)
 | |
|         c3 = _unsqueeze_to_dim(c3, 4)
 | |
|         d_input = (
 | |
|             torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
 | |
|             + torch.mul(input.reshape(N, group, cpg, HxW), c2)
 | |
|             + c3
 | |
|         )
 | |
|         d_input = d_input.reshape(input.shape).to(input.dtype)
 | |
|     if output_mask[1]:
 | |
|         d_gamma = (
 | |
|             (
 | |
|                 (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
 | |
|                 * rstd.unsqueeze(-1)
 | |
|             )
 | |
|             .sum(dim=[0])
 | |
|             .reshape(C)
 | |
|         )
 | |
|     if output_mask[2]:
 | |
|         d_bias = db.sum(dim=[0])
 | |
| 
 | |
|     return (d_input, d_gamma, d_bias)
 | |
| 
 | |
| 
 | |
| # out_wrapper currently does not allow optional outputs
 | |
| @register_decomposition(aten.native_group_norm_backward.out)
 | |
| def native_group_norm_backward_out(
 | |
|     grad_output: Tensor,
 | |
|     input: Tensor,
 | |
|     mean: Tensor,
 | |
|     rstd: Tensor,
 | |
|     gamma: Optional[Tensor],
 | |
|     N: int,
 | |
|     C: int,
 | |
|     HxW: int,
 | |
|     group: int,
 | |
|     output_mask: list[bool],
 | |
|     *,
 | |
|     out0: torch.Tensor,
 | |
|     out1: torch.Tensor,
 | |
|     out2: torch.Tensor,
 | |
| ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
 | |
|     result = native_group_norm_backward(
 | |
|         grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
 | |
|     )
 | |
|     grad_input = (out0, out1, out2)
 | |
|     for i, r in enumerate(result):
 | |
|         if r is not None:
 | |
|             _maybe_resize_out(grad_input[i], r.shape)
 | |
|             _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
 | |
| 
 | |
|     return grad_input
 | |
| 
 | |
| 
 | |
| def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
 | |
|     if x is not None:
 | |
|         return x.to(dtype)
 | |
|     return x
 | |
| 
 | |
| 
 | |
| # TODO: Take a closer look at the type promotion semantics
 | |
| @register_decomposition(aten.native_layer_norm_backward.default)
 | |
| def native_layer_norm_backward(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     normalized_shape: list[int],
 | |
|     mean: Tensor,
 | |
|     rstd: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     output_mask: list[bool],
 | |
| ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
 | |
|     input_shape = input.shape
 | |
|     input_ndim = input.dim()
 | |
|     computation_dtype = utils.get_computation_dtype(input.dtype)
 | |
|     grad_out_cast, input_cast, weight_cast, bias_cast = (
 | |
|         x.to(computation_dtype, memory_format=torch.contiguous_format)
 | |
|         if x is not None
 | |
|         else x
 | |
|         for x in (grad_out, input, weight, bias)
 | |
|     )
 | |
|     assert grad_out_cast is not None
 | |
| 
 | |
|     axis = input_ndim - len(normalized_shape)
 | |
|     inner_dims = input_shape[axis:]
 | |
|     outer_dims = input_shape[:axis]
 | |
|     inner_dim_indices: list[int] = []
 | |
|     outer_dim_indices: list[int] = []
 | |
|     for i in range(input_ndim):
 | |
|         if i >= axis:
 | |
|             inner_dim_indices.append(i)
 | |
|         else:
 | |
|             outer_dim_indices.append(i)
 | |
| 
 | |
|     N = prod(inner_dims)  # type: ignore[arg-type]
 | |
|     M = prod(outer_dims)  # type: ignore[arg-type]
 | |
|     from torch.fx.experimental.symbolic_shapes import statically_known_true
 | |
| 
 | |
|     if statically_known_true(M == 0) or statically_known_true(N == 0):
 | |
|         return (
 | |
|             input.new_zeros(input_shape) if output_mask[0] else None,
 | |
|             input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
 | |
|             input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
 | |
|         )
 | |
|     mean = _unsqueeze_to_dim(mean, input_cast.dim())  # type: ignore[union-attr]
 | |
|     rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
 | |
|     assert input_cast is not None
 | |
|     x_hat = (input_cast - mean) * rstd
 | |
|     if weight_cast is not None:
 | |
|         grad_x_hat = grad_out_cast * weight_cast
 | |
|     else:
 | |
|         grad_x_hat = grad_out_cast
 | |
|     a = grad_x_hat * N
 | |
|     b = torch.sum(grad_x_hat, inner_dim_indices, True)
 | |
|     c1 = torch.mul(grad_x_hat, x_hat)
 | |
|     c2 = torch.sum(c1, inner_dim_indices, True)
 | |
|     c3 = torch.mul(x_hat, c2)
 | |
| 
 | |
|     inner = a - b - c3
 | |
|     d_input: Optional[Tensor] = None
 | |
|     d_weight: Optional[Tensor] = None
 | |
|     d_bias: Optional[Tensor] = None
 | |
|     if output_mask[0]:
 | |
|         d_input = (rstd / N) * inner
 | |
| 
 | |
|     if output_mask[1] and weight_cast is not None:
 | |
|         if len(outer_dim_indices) > 0:
 | |
|             d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
 | |
|         else:
 | |
|             d_weight = grad_out_cast * x_hat
 | |
| 
 | |
|     if output_mask[2] and bias_cast is not None:
 | |
|         if len(outer_dim_indices) > 0:
 | |
|             d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
 | |
|         else:
 | |
|             d_bias = grad_out_cast.clone()
 | |
| 
 | |
|     return (
 | |
|         _maybe_cast(d_input, input.dtype),
 | |
|         _maybe_cast(d_weight, weight.dtype if weight is not None else None),
 | |
|         _maybe_cast(d_bias, bias.dtype if bias is not None else None),
 | |
|     )
 | |
| 
 | |
| 
 | |
| # out_wrapper currently does not allow optional outputs
 | |
| @register_decomposition(aten.native_layer_norm_backward.out)
 | |
| def native_layer_norm_backward_out(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     normalized_shape: list[int],
 | |
|     mean: Tensor,
 | |
|     rstd: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     output_mask: list[bool],
 | |
|     *,
 | |
|     out0: torch.Tensor,
 | |
|     out1: torch.Tensor,
 | |
|     out2: torch.Tensor,
 | |
| ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
 | |
|     result = native_layer_norm_backward(
 | |
|         grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
 | |
|     )
 | |
|     grad_input = (out0, out1, out2)
 | |
|     for i, r in enumerate(result):
 | |
|         if r is not None:
 | |
|             _maybe_resize_out(grad_input[i], r.shape)
 | |
|             _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
 | |
| 
 | |
|     return grad_input
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._fused_rms_norm_backward.default)
 | |
| def _fused_rms_norm_backward(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     normalized_shape: list[int],
 | |
|     rstd: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     output_mask: list[bool],
 | |
| ) -> tuple[Optional[Tensor], Optional[Tensor]]:
 | |
|     input_shape = input.shape
 | |
|     input_ndim = input.dim()
 | |
|     computation_dtype = utils.get_computation_dtype(input.dtype)
 | |
| 
 | |
|     grad_out_cast = grad_out.to(
 | |
|         computation_dtype, memory_format=torch.contiguous_format
 | |
|     )
 | |
|     input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
 | |
|     weight_cast = (
 | |
|         weight.to(computation_dtype, memory_format=torch.contiguous_format)
 | |
|         if weight is not None
 | |
|         else None
 | |
|     )
 | |
|     assert grad_out_cast is not None
 | |
| 
 | |
|     axis = input_ndim - len(normalized_shape)
 | |
|     inner_dims = input_shape[axis:]
 | |
|     outer_dims = input_shape[:axis]
 | |
|     inner_dim_indices: list[int] = []
 | |
|     outer_dim_indices: list[int] = []
 | |
|     for i in range(input_ndim):
 | |
|         if i >= axis:
 | |
|             inner_dim_indices.append(i)
 | |
|         else:
 | |
|             outer_dim_indices.append(i)
 | |
| 
 | |
|     N = prod(inner_dims)  # type: ignore[arg-type]
 | |
|     M = prod(outer_dims)  # type: ignore[arg-type]
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     if guard_or_false(M == 0) or guard_or_false(N == 0):
 | |
|         return (
 | |
|             input.new_zeros(input_shape) if output_mask[0] else None,
 | |
|             input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
 | |
|         )
 | |
| 
 | |
|     rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
 | |
|     if weight_cast is not None:
 | |
|         grad_x_hat = grad_out_cast * weight_cast
 | |
|     else:
 | |
|         grad_x_hat = grad_out_cast
 | |
| 
 | |
|     d_input: Optional[Tensor] = None
 | |
|     d_weight: Optional[Tensor] = None
 | |
| 
 | |
|     x_hat = input_cast * rstd
 | |
| 
 | |
|     if output_mask[0]:
 | |
|         sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
 | |
|         d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
 | |
| 
 | |
|     if output_mask[1] and weight_cast is not None:
 | |
|         d_weight_full_shape = grad_out_cast * x_hat
 | |
|         if len(outer_dim_indices) > 0:
 | |
|             d_weight = torch.sum(
 | |
|                 d_weight_full_shape, dim=outer_dim_indices, keepdim=False
 | |
|             )
 | |
|         else:
 | |
|             d_weight = d_weight_full_shape
 | |
| 
 | |
|     return (
 | |
|         _maybe_cast(d_input, input.dtype),
 | |
|         _maybe_cast(d_weight, input.dtype),
 | |
|     )
 | |
| 
 | |
| 
 | |
| def native_batch_norm_helper(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
|     functional: bool,
 | |
| ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
 | |
|     reduction_dims = [0] + list(range(2, input.dim()))
 | |
|     computation_dtype = utils.get_computation_dtype(input.dtype)
 | |
|     new_running_mean = running_mean
 | |
|     new_running_var = running_var
 | |
|     if training:
 | |
|         computation_dtype = utils.get_computation_dtype(input.dtype)
 | |
|         input_acc = input.to(dtype=computation_dtype)
 | |
|         biased_var, mean = torch.var_mean(
 | |
|             input_acc, dim=reduction_dims, correction=0, keepdim=True
 | |
|         )
 | |
|         rstd = torch.rsqrt(biased_var + eps)
 | |
| 
 | |
|         output = (input - mean) * rstd
 | |
| 
 | |
|         save_mean = torch.squeeze(mean, reduction_dims)
 | |
|         save_rstd = torch.squeeze(rstd, reduction_dims)
 | |
|         if running_mean is not None:
 | |
|             new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
 | |
|             if not functional:
 | |
|                 running_mean.copy_(new_running_mean)
 | |
|         if running_var is not None:
 | |
|             n = input.numel() / input.shape[1]
 | |
|             # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
 | |
|             # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
 | |
|             # numerics probably don't matter.
 | |
|             squeezed_var = torch.squeeze(biased_var, reduction_dims)
 | |
|             unbiased_var = squeezed_var * (n / (n - 1))
 | |
|             new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
 | |
|             if not functional:
 | |
|                 running_var.copy_(new_running_var)
 | |
|     else:
 | |
|         assert running_mean is not None and running_var is not None
 | |
|         running_mean = running_mean.to(dtype=computation_dtype, copy=True)
 | |
|         new_running_mean = running_mean
 | |
|         running_var = running_var.to(dtype=computation_dtype, copy=True)
 | |
|         new_running_var = running_var
 | |
|         mean = running_mean
 | |
|         invstd = 1 / (torch.sqrt(running_var + eps))
 | |
|         # Very annoying inconsistency where CPU and CUDA give different shapes
 | |
|         if input.device.type != "cpu":
 | |
|             save_mean = running_mean
 | |
|             save_rstd = invstd
 | |
|         else:
 | |
|             save_mean = input.new_zeros((0,))
 | |
|             save_rstd = input.new_zeros((0,))
 | |
|         mean = _unsqueeze_to_dim(mean, input.dim() - 1)
 | |
|         invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
 | |
|         output = (input - mean) * invstd
 | |
| 
 | |
|     if weight is not None:
 | |
|         weight = weight.flatten()
 | |
|         weight = _unsqueeze_to_dim(weight, input.dim() - 1)
 | |
|         output = output * weight
 | |
| 
 | |
|     if bias is not None:
 | |
|         bias = bias.flatten()
 | |
|         bias = _unsqueeze_to_dim(bias, input.dim() - 1)
 | |
|         output = output + bias
 | |
| 
 | |
|     if input.device.type == "cpu":
 | |
|         save_mean = save_mean.to(dtype=input.dtype)
 | |
|         save_rstd = save_rstd.to(dtype=input.dtype)
 | |
|     return (
 | |
|         output.to(dtype=input.dtype),
 | |
|         save_mean,
 | |
|         save_rstd,
 | |
|         new_running_mean,
 | |
|         new_running_var,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.native_batch_norm)
 | |
| @out_wrapper("out", "save_mean", "save_invstd")
 | |
| def native_batch_norm(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
 | |
|         input, weight, bias, running_mean, running_var, training, momentum, eps, False
 | |
|     )
 | |
|     return output, save_mean, save_rstd
 | |
| 
 | |
| 
 | |
| # TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
 | |
| # with our new correctly schema'd _native_batch_norm_legit and its variants, but
 | |
| # we cannot do that immediately in the C++ because it would be forwards incompatible
 | |
| # with some mobile use cases.
 | |
| #
 | |
| # Since this change is most impactful for aot autograd/functionalization, we simply
 | |
| # register this decomposition on the Autograd key for the python dispatcher (which is
 | |
| # currently only used by aot autograd/functionalization and no one else, really).
 | |
| # In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
 | |
| # to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
 | |
| @aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
 | |
| @aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| def native_batch_norm_decomposition(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     if running_mean is None and running_var is None:
 | |
|         return aten._native_batch_norm_legit(
 | |
|             input, weight, bias, training, momentum, eps
 | |
|         )
 | |
|     if running_mean is None:
 | |
|         raise RuntimeError(
 | |
|             "running_mean is None, but running_var is provided. "
 | |
|             "They should both be None or both be provided."
 | |
|         )
 | |
|     if running_var is None:
 | |
|         raise RuntimeError(
 | |
|             "running_var is None, but running_mean is provided. "
 | |
|             "They should both be None or both be provided."
 | |
|         )
 | |
|     if training:
 | |
|         # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg.
 | |
|         return aten._native_batch_norm_legit(
 | |
|             input, weight, bias, running_mean, running_var, training, momentum, eps
 | |
|         )
 | |
|     else:
 | |
|         return aten._native_batch_norm_legit_no_training(
 | |
|             input, weight, bias, running_mean, running_var, momentum, eps
 | |
|         )
 | |
| 
 | |
| 
 | |
| @aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]:
 | |
|     dim_size = tensor.size(dim)
 | |
|     split_size = (dim_size + chunks - 1) // chunks
 | |
| 
 | |
|     if split_size == 0 and dim_size == 0:
 | |
|         split_sizes = [split_size for _ in chunks]
 | |
|         split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
 | |
|         return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
 | |
|     return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._native_batch_norm_legit_no_training.default)
 | |
| def _native_batch_norm_legit_no_training(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     return aten._native_batch_norm_legit.default(
 | |
|         input,
 | |
|         weight,
 | |
|         bias,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         False,  # training
 | |
|         momentum,
 | |
|         eps,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._native_batch_norm_legit.default)
 | |
| def _native_batch_norm_legit(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
 | |
|         input, weight, bias, running_mean, running_var, training, momentum, eps, False
 | |
|     )
 | |
|     return output, save_mean, save_rstd
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._native_batch_norm_legit.no_stats)
 | |
| def _native_batch_norm_legit_no_stats(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor]:
 | |
|     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
 | |
|         input, weight, bias, None, None, training, momentum, eps, False
 | |
|     )
 | |
|     return output, save_mean, save_rstd
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._native_batch_norm_legit_functional.default)
 | |
| def _native_batch_norm_legit_functional(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     training: bool,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
 | |
|     (
 | |
|         output,
 | |
|         save_mean,
 | |
|         save_rstd,
 | |
|         new_running_mean,
 | |
|         new_running_var,
 | |
|     ) = native_batch_norm_helper(
 | |
|         input, weight, bias, running_mean, running_var, training, momentum, eps, True
 | |
|     )
 | |
|     assert new_running_mean is not None, "new_running_mean should not be None"
 | |
|     assert new_running_var is not None, "new_running_var should not be None"
 | |
|     return output, save_mean, save_rstd, new_running_mean, new_running_var
 | |
| 
 | |
| 
 | |
| def _get_batch_norm_reserve_tensor(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     eps: float,
 | |
|     training: bool,
 | |
| ) -> Tensor:
 | |
|     """
 | |
|     Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
 | |
|     backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
 | |
|     which support a variety of backends including cudnn. We create this tensor here to get
 | |
|     the correct shape in the traced graph if we detect that will call the cudnn kernel,
 | |
|     and rely on DCE to avoid materializing this tensor.
 | |
|     """
 | |
|     backend = torch._C._select_batch_norm_backend(  # type: ignore[attr-defined]
 | |
|         input, weight, bias, running_mean, running_var, True, eps
 | |
|     )
 | |
|     reserve_size = 0
 | |
|     if backend == torch._C._BatchNormBackend.Cudnn:  # type: ignore[attr-defined]
 | |
|         reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size(  # type: ignore[attr-defined]
 | |
|             input, training
 | |
|         )
 | |
|     return torch.empty(
 | |
|         reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._batch_norm_with_update.default)
 | |
| def _batch_norm_with_update(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
 | |
|     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
 | |
|         input,
 | |
|         weight,
 | |
|         bias,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         True,  # training
 | |
|         momentum,
 | |
|         eps,
 | |
|         False,  # functional
 | |
|     )
 | |
|     reserve = _get_batch_norm_reserve_tensor(
 | |
|         input, weight, bias, running_mean, running_var, eps, training=True
 | |
|     )
 | |
|     return output, save_mean, save_rstd, reserve
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._batch_norm_with_update_functional.default)
 | |
| def _batch_norm_with_update_functional(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
 | |
|     (
 | |
|         output,
 | |
|         save_mean,
 | |
|         save_rstd,
 | |
|         new_rm,
 | |
|         new_rv,
 | |
|     ) = native_batch_norm_helper(
 | |
|         input, weight, bias, running_mean, running_var, True, momentum, eps, True
 | |
|     )
 | |
|     reserve = _get_batch_norm_reserve_tensor(
 | |
|         input, weight, bias, running_mean, running_var, eps, training=True
 | |
|     )
 | |
|     assert new_rm is not None, "new_running_mean should not be None"
 | |
|     assert new_rv is not None, "new_running_var should not be None"
 | |
|     return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._batch_norm_no_update.default)
 | |
| def _batch_norm_no_update(
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Tensor,
 | |
|     running_var: Tensor,
 | |
|     momentum: float,
 | |
|     eps: float,
 | |
| ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
 | |
|     output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
 | |
|         input,
 | |
|         weight,
 | |
|         bias,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         False,  # training
 | |
|         momentum,
 | |
|         eps,
 | |
|         False,  # functional
 | |
|     )
 | |
|     reserve = _get_batch_norm_reserve_tensor(
 | |
|         input, weight, bias, running_mean, running_var, eps, training=False
 | |
|     )
 | |
|     return output, save_mean, save_rstd, reserve
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._fused_dropout)
 | |
| @out_wrapper("out0", "out1")
 | |
| @pw_cast_for_opmath
 | |
| def _fused_dropout_decomposition(input, p, generator=None):
 | |
|     assert generator is None
 | |
|     mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
 | |
|     res = mask.type_as(input) * input * (1.0 / p)
 | |
|     return (res, mask)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._to_copy)
 | |
| @out_wrapper()
 | |
| def _to_copy(
 | |
|     x: Union[Tensor, NumberType],
 | |
|     *,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     layout=None,
 | |
|     device: Optional[torch.device] = None,
 | |
|     pin_memory: bool = False,
 | |
|     non_blocking: bool = False,
 | |
|     memory_format: Optional[torch.memory_format] = None,
 | |
| ):
 | |
|     assert not layout or layout == torch.strided, "TODO"
 | |
|     assert not pin_memory, "TODO"
 | |
|     assert isinstance(x, (torch.Tensor, int, float, bool, complex))
 | |
|     if device is None and dtype is None and memory_format is None:
 | |
|         if isinstance(x, torch.Tensor):
 | |
|             return x.clone()
 | |
|         else:
 | |
|             return x
 | |
|     dtype_converted = False
 | |
| 
 | |
|     if isinstance(x, torch.Tensor):
 | |
|         x_tensor = x
 | |
|     else:
 | |
|         x_tensor = torch.scalar_tensor(x)
 | |
| 
 | |
|     if device is not None and device != x_tensor.device:
 | |
|         # avoid conversions on cpu
 | |
|         if dtype is not None and device.type == "cpu":
 | |
|             x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
 | |
|             dtype_converted = True
 | |
|         x_tensor = torch._prims.device_put(x_tensor, device, non_blocking)
 | |
| 
 | |
|     if dtype is not None and not dtype_converted:
 | |
|         x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
 | |
|         dtype_converted = True
 | |
| 
 | |
|     if memory_format is not None:  # no ref/prim for memory format
 | |
|         return torch.clone(x_tensor, memory_format=memory_format)
 | |
|     return x_tensor
 | |
| 
 | |
| 
 | |
| # Questionable decompositions
 | |
| # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
 | |
| # Note that this decomposition causes issues with in-place ops
 | |
| @register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
 | |
| @out_wrapper()
 | |
| def nop_decomposition(x):
 | |
|     return aten.alias(x)
 | |
| 
 | |
| 
 | |
| # Also register to the Autograd dispatch key, so this decomp can run above autograd.
 | |
| # native_batch_norm needs to decompose into other ops before autograd.
 | |
| @aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
 | |
| @register_decomposition(aten.cudnn_batch_norm)
 | |
| @out_wrapper("out0", "out1", "out2", "out3")
 | |
| def cudnn_batch_norm(
 | |
|     input: Tensor,
 | |
|     weight: Tensor,
 | |
|     bias: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     training: bool,
 | |
|     exponential_average_factor: float,
 | |
|     epsilon: float,
 | |
| ):
 | |
|     a, b, c = aten.native_batch_norm(
 | |
|         input,
 | |
|         weight,
 | |
|         bias,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         training,
 | |
|         exponential_average_factor,
 | |
|         epsilon,
 | |
|     )
 | |
|     # Cudnn return running mean and variance when training is True
 | |
|     if training:
 | |
|         return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
 | |
|     return (
 | |
|         a,
 | |
|         weight.new_zeros((0,)),
 | |
|         weight.new_zeros((0,)),
 | |
|         input.new_zeros((0,), dtype=torch.uint8),
 | |
|     )
 | |
| 
 | |
| 
 | |
| def _broadcast_batch_norm_backward(x, broadcast_mask):
 | |
|     for axis, mask in enumerate(broadcast_mask):
 | |
|         if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask):
 | |
|             x = x.unsqueeze(axis)
 | |
|     return x
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.batch_norm_backward.default)
 | |
| def batch_norm_backward(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     save_mean: Optional[Tensor],
 | |
|     save_invstd: Optional[Tensor],
 | |
|     train: bool,
 | |
|     eps: float,
 | |
|     output_mask: list[bool],
 | |
|     reserve: Tensor,
 | |
| ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
 | |
|     return native_batch_norm_backward(
 | |
|         grad_out,
 | |
|         input,
 | |
|         weight,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         save_mean,
 | |
|         save_invstd,
 | |
|         train,
 | |
|         eps,
 | |
|         output_mask,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.native_batch_norm_backward.default)
 | |
| def native_batch_norm_backward(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     save_mean: Optional[Tensor],
 | |
|     save_invstd: Optional[Tensor],
 | |
|     train: bool,
 | |
|     eps: float,
 | |
|     output_mask: list[bool],
 | |
| ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
 | |
|     input_dtype = input.dtype
 | |
|     if weight is not None:
 | |
|         weight_dtype = weight.dtype
 | |
|     else:
 | |
|         weight_dtype = input_dtype
 | |
|     computation_dtype = utils.get_computation_dtype(input.dtype)
 | |
|     (
 | |
|         grad_out_cast,
 | |
|         input_cast,
 | |
|         weight_cast,
 | |
|         running_mean_cast,
 | |
|         running_var_cast,
 | |
|         save_mean_cast,
 | |
|         save_invstd_cast,
 | |
|     ) = (
 | |
|         x.to(computation_dtype) if x is not None else x
 | |
|         for x in (
 | |
|             grad_out,
 | |
|             input,
 | |
|             weight,
 | |
|             running_mean,
 | |
|             running_var,
 | |
|             save_mean,
 | |
|             save_invstd,
 | |
|         )
 | |
|     )
 | |
|     input_shape = input.shape
 | |
|     input_rank = input.dim()
 | |
|     assert input_rank >= 2, "rank of the input must be at least 2"
 | |
| 
 | |
|     axis = 1
 | |
|     num_features = prod(list(input_shape)) / input_shape[axis]
 | |
|     mean = save_mean_cast
 | |
|     invstd = save_invstd_cast
 | |
|     if train:
 | |
|         assert mean is not None and invstd is not None
 | |
| 
 | |
|     else:
 | |
|         assert running_mean_cast is not None and running_var_cast is not None
 | |
|         mean = running_mean_cast
 | |
|         invstd = torch.rsqrt(running_var_cast + eps)
 | |
| 
 | |
|     broadcast_mask: list[int] = [1] * input_rank
 | |
|     broadcast_mask[axis] = input_shape[axis]
 | |
| 
 | |
|     reduction_axes: list[int] = []
 | |
|     for i in range(input_rank):
 | |
|         if i != axis:
 | |
|             reduction_axes.append(i)
 | |
| 
 | |
|     mean = _broadcast_batch_norm_backward(mean, broadcast_mask)  # type: ignore[arg-type]
 | |
|     norm = 1.0 / num_features
 | |
|     grad_output_sum = torch.sum(grad_out_cast, reduction_axes)  # type: ignore[arg-type]
 | |
|     dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes)  # type: ignore[operator]
 | |
| 
 | |
|     grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
 | |
|     proj_scale = _broadcast_batch_norm_backward(
 | |
|         torch.mul(dot_p * norm, invstd * invstd),  # type: ignore[operator]
 | |
|         broadcast_mask,
 | |
|     )
 | |
| 
 | |
|     if weight_cast is None:
 | |
|         grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0  # type: ignore[arg-type]
 | |
|     else:
 | |
|         grad_scale = _broadcast_batch_norm_backward(
 | |
|             invstd * weight_cast, broadcast_mask
 | |
|         )
 | |
| 
 | |
|     if train:
 | |
|         proj = (input_cast - mean) * proj_scale  # type: ignore[operator]
 | |
|         grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
 | |
|     else:
 | |
|         grad_input = grad_out_cast * grad_scale
 | |
| 
 | |
|     if output_mask[1]:
 | |
|         grad_weight = dot_p * invstd
 | |
|     else:
 | |
|         grad_weight = None  # "None" doesn't work with vjp, should use zeros for vjp
 | |
| 
 | |
|     if output_mask[2]:
 | |
|         grad_bias = grad_output_sum
 | |
|     else:
 | |
|         grad_bias = None  # "None" doesn't work with vjp, should use zeros for vjp
 | |
| 
 | |
|     return (
 | |
|         grad_input.to(input_dtype),
 | |
|         _maybe_cast(grad_weight, weight_dtype),
 | |
|         _maybe_cast(grad_bias, weight_dtype),
 | |
|     )
 | |
| 
 | |
| 
 | |
| # out_wrapper currently does not allow optional outputs
 | |
| @register_decomposition(aten.native_batch_norm_backward.out)
 | |
| def native_batch_norm_backward_out(
 | |
|     grad_out: Tensor,
 | |
|     input: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     save_mean: Optional[Tensor],
 | |
|     save_invstd: Optional[Tensor],
 | |
|     train: bool,
 | |
|     eps: float,
 | |
|     output_mask: list[bool],
 | |
|     *,
 | |
|     out0: torch.Tensor,
 | |
|     out1: torch.Tensor,
 | |
|     out2: torch.Tensor,
 | |
| ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
 | |
|     result = native_batch_norm_backward(
 | |
|         grad_out,
 | |
|         input,
 | |
|         weight,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         save_mean,
 | |
|         save_invstd,
 | |
|         train,
 | |
|         eps,
 | |
|         output_mask,
 | |
|     )
 | |
|     grad_input = (out0, out1, out2)
 | |
|     for i, r in enumerate(result):
 | |
|         if r is not None:
 | |
|             _maybe_resize_out(grad_input[i], r.shape)
 | |
|             _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
 | |
| 
 | |
|     return grad_input
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.miopen_batch_norm_backward)
 | |
| @out_wrapper("out0", "out1", "out2")
 | |
| def miopen_batch_norm_backward(
 | |
|     input: Tensor,
 | |
|     grad_output: Tensor,
 | |
|     weight: Tensor,
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     save_mean: Optional[Tensor],
 | |
|     save_var: Optional[Tensor],
 | |
|     epsilon: float,
 | |
| ):
 | |
|     return aten.native_batch_norm_backward(
 | |
|         grad_output,
 | |
|         input,
 | |
|         weight,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         save_mean,
 | |
|         save_var,
 | |
|         True,
 | |
|         epsilon,
 | |
|         [True, True, True],
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.cudnn_batch_norm_backward)
 | |
| @out_wrapper("out0", "out1", "out2")
 | |
| def cudnn_batch_norm_backward(
 | |
|     input: Tensor,
 | |
|     grad_output: Tensor,
 | |
|     weight: Tensor,
 | |
|     running_mean: Optional[Tensor],
 | |
|     running_var: Optional[Tensor],
 | |
|     save_mean: Optional[Tensor],
 | |
|     save_var: Optional[Tensor],
 | |
|     epsilon: float,
 | |
|     reserveSpace: Tensor,
 | |
| ):
 | |
|     return aten.native_batch_norm_backward(
 | |
|         grad_output,
 | |
|         input,
 | |
|         weight,
 | |
|         running_mean,
 | |
|         running_var,
 | |
|         save_mean,
 | |
|         save_var,
 | |
|         True,
 | |
|         epsilon,
 | |
|         [True, True, True],
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._adaptive_avg_pool2d)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]):
 | |
|     # Preconditions
 | |
|     device = input.device
 | |
|     shape = input.shape
 | |
|     ndim = len(shape)
 | |
|     torch._check(
 | |
|         ndim in (3, 4),
 | |
|         lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
 | |
|     )
 | |
|     for d in input.shape[-2:]:
 | |
|         torch._check(
 | |
|             d != 0,
 | |
|             lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
 | |
|             f"non-batch dimensions, but input has shape {tuple(shape)}.",
 | |
|         )
 | |
| 
 | |
|     # Optimisation (we should also do this in the kernel implementation)
 | |
|     if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
 | |
|         stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
 | |
|         kernel = tuple(
 | |
|             i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
 | |
|         )
 | |
|         return torch.nn.functional.avg_pool2d(input, kernel, stride)
 | |
| 
 | |
|     def start_index(a, b, c):
 | |
|         return torch.div(a * c, b, rounding_mode="trunc")
 | |
| 
 | |
|     def end_index(a, b, c):
 | |
|         return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")
 | |
| 
 | |
|     def compute_idx(in_size, out_size):
 | |
|         orange = torch.arange(out_size, device=device, dtype=torch.int64)
 | |
|         i0 = start_index(orange, out_size, in_size)
 | |
|         # Let length = end_index - start_index, i.e. the length of the pooling kernels
 | |
|         # length.max() can be computed analytically as follows:
 | |
|         maxlength = in_size // out_size + 1
 | |
|         in_size_mod = in_size % out_size
 | |
|         # adaptive = True iff there are kernels with different lengths
 | |
|         adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
 | |
|         if adaptive:
 | |
|             maxlength += 1
 | |
|         elif in_size_mod == 0:
 | |
|             maxlength -= 1
 | |
| 
 | |
|         range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
 | |
|         idx = i0.unsqueeze(-1) + range_max
 | |
|         if adaptive:
 | |
|             # Need to clamp to avoid accessing out-of-bounds memory
 | |
|             # TODO make minimum accept scalars
 | |
|             maxval = torch.scalar_tensor(
 | |
|                 in_size - 1, dtype=idx.dtype, device=idx.device
 | |
|             )
 | |
|             idx = torch.minimum(idx, maxval)
 | |
| 
 | |
|             # Compute the length
 | |
|             i1 = end_index(orange, out_size, in_size)
 | |
|             length = i1 - i0
 | |
|         else:
 | |
|             length = maxlength
 | |
|         return idx, length, range_max, adaptive
 | |
| 
 | |
|     # length is not None if it's constant, otherwise we'll need to compute it
 | |
|     idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
 | |
|     idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])
 | |
| 
 | |
|     vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
 | |
|     # Shortcut for the simpler case
 | |
|     if not adaptive_h and not adaptive_w:
 | |
|         return torch.mean(vals, dim=(-3, -1))
 | |
| 
 | |
|     def maybe_mask(vals, length, range_max, adaptive, dim):
 | |
|         if isinstance(length, IntLike):
 | |
|             return vals, length
 | |
|         else:
 | |
|             # zero-out the things we didn't really want to select
 | |
|             assert dim < 0
 | |
|             # hack
 | |
|             mask = range_max >= length.unsqueeze(-1)
 | |
|             if dim == -2:
 | |
|                 mask = _unsqueeze_to_dim(mask, 4)
 | |
|             vals = torch.masked_fill(vals, mask, 0.0)
 | |
|             # Compute the length of each window
 | |
|             length = _unsqueeze_to_dim(length, -dim)
 | |
|             return vals, length
 | |
| 
 | |
|     vals, length_h = maybe_mask(
 | |
|         vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
 | |
|     )
 | |
|     vals, length_w = maybe_mask(
 | |
|         vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
 | |
|     )
 | |
| 
 | |
|     # We unroll the sum as we assume that the kernels are going to be small
 | |
|     ret = None
 | |
|     for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
 | |
|         if ret is None:
 | |
|             ret = vals[..., i, :, j]
 | |
|         else:
 | |
|             ret = ret + vals[..., i, :, j]
 | |
|     return ret / (length_h * length_w)
 | |
| 
 | |
| 
 | |
| def _max_unpoolnd(
 | |
|     self: TensorLike, indices: TensorLike, output_size: list[int], dim: int
 | |
| ):
 | |
|     # If the input tensors self and indices came from max_pool call as
 | |
|     # required by the documentation, this operation is deterministic
 | |
|     # because that ensures that if there are two entries in `indices`
 | |
|     # tensor that are equal, the corresponding values in `self` are also
 | |
|     # equal. If this condition is not satisfied, the operation is
 | |
|     # non-deterministic as one of the different values in `self` 'wins'.
 | |
|     utils.alert_not_deterministic(f"max_unpooling{dim}d_forward_out")
 | |
|     nc = reduce(operator.mul, self.shape[:-dim])
 | |
|     hw = reduce(operator.mul, output_size)
 | |
|     indices_nc_shape = [1] * self.ndim
 | |
|     indices_nc_shape[:-dim] = self.shape[:-dim]
 | |
|     indices_flat = (
 | |
|         indices + aten.arange(nc, device=self.device).view(indices_nc_shape) * hw
 | |
|     ).reshape(-1)
 | |
| 
 | |
|     output = self.new_zeros(list(self.shape[:-dim]) + list(output_size))
 | |
|     return aten._unsafe_index_put(
 | |
|         output.reshape(-1), [indices_flat], self.reshape(-1), accumulate=False
 | |
|     ).view(output.shape)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.max_unpool2d)
 | |
| @out_wrapper()
 | |
| def max_unpool2d(
 | |
|     self: TensorLike,
 | |
|     indices: TensorLike,
 | |
|     output_size: list[int],
 | |
| ):
 | |
|     torch._check(
 | |
|         indices.dtype == torch.int64,
 | |
|         lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
 | |
|     )
 | |
|     torch._check(
 | |
|         len(output_size) == 2,
 | |
|         lambda: (
 | |
|             f"There should be exactly two elements (height, width) in output_size, "
 | |
|             f"but got {len(output_size)} elements."
 | |
|         ),
 | |
|     )
 | |
| 
 | |
|     torch._check(
 | |
|         self.ndim in (3, 4),
 | |
|         lambda: (
 | |
|             f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
 | |
|             f"but got a tensor with {self.ndim} dimensions."
 | |
|         ),
 | |
|     )
 | |
|     torch._check(
 | |
|         self.shape == indices.shape,
 | |
|         lambda: (
 | |
|             f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
 | |
|             f"but got indices tensor with shape: {indices.shape}"
 | |
|         ),
 | |
|     )
 | |
| 
 | |
|     for i in range(1, self.ndim):
 | |
|         torch._check(
 | |
|             self.size(i) > 0,
 | |
|             lambda: (
 | |
|                 f"max_unpooling2d(): "
 | |
|                 f"Expected input to have non-zero size for non-batch dimensions, "
 | |
|                 f"but got {self.shape} with dimension {i} being empty."
 | |
|             ),
 | |
|         )
 | |
| 
 | |
|     return _max_unpoolnd(self, indices, output_size, 2)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.max_unpool3d)
 | |
| @out_wrapper()
 | |
| def max_unpool3d(
 | |
|     input: TensorLike,
 | |
|     indices: TensorLike,
 | |
|     output_size: list[int],
 | |
|     stride: list[int],
 | |
|     padding: list[int],
 | |
| ):
 | |
|     torch._check(
 | |
|         indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
 | |
|     )
 | |
|     torch._check(
 | |
|         input.ndim in (4, 5),
 | |
|         lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
 | |
|     )
 | |
|     torch._check(
 | |
|         len(output_size) == 3,
 | |
|         lambda: (
 | |
|             f"There should be exactly three elements (depth, height, width) in output_size, "
 | |
|             f"but got {len(output_size)} elements."
 | |
|         ),
 | |
|     )
 | |
|     torch._check(
 | |
|         len(stride) == 3,
 | |
|         lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
 | |
|     )
 | |
|     torch._check(
 | |
|         len(padding) == 3,
 | |
|         lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
 | |
|     )
 | |
|     torch._check(
 | |
|         input.shape == indices.shape,
 | |
|         lambda: (
 | |
|             f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
 | |
|             f"but got indices tensor with shape: {indices.shape}"
 | |
|         ),
 | |
|     )
 | |
| 
 | |
|     for i in range(1, input.ndim):
 | |
|         torch._check(
 | |
|             input.size(i) > 0,
 | |
|             lambda: (
 | |
|                 f"max_unpooling3d(): "
 | |
|                 f"Expected input to have non-zero size for non-batch dimensions, "
 | |
|                 f"but got {input.shape} with dimension {i} being empty."
 | |
|             ),
 | |
|         )
 | |
| 
 | |
|     torch._check(
 | |
|         stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
 | |
|         lambda: f"strides should be greater than zero, but got stride: {stride}",
 | |
|     )
 | |
| 
 | |
|     return _max_unpoolnd(input, indices, output_size, 3)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.index_add_)
 | |
| def index_add_(
 | |
|     x: TensorLike,
 | |
|     dim: int,
 | |
|     index: TensorLike,
 | |
|     tensor: TensorLike,
 | |
|     *,
 | |
|     alpha: NumberType = 1,
 | |
| ):
 | |
|     return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.index_add)
 | |
| @out_wrapper()
 | |
| def index_add(
 | |
|     x: TensorLike,
 | |
|     dim: int,
 | |
|     index: TensorLike,
 | |
|     tensor: TensorLike,
 | |
|     *,
 | |
|     alpha: NumberType = 1,
 | |
| ):
 | |
|     return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
 | |
| 
 | |
| 
 | |
| def _index_add(
 | |
|     x: TensorLike,
 | |
|     dim: int,
 | |
|     index: TensorLike,
 | |
|     tensor: TensorLike,
 | |
|     *,
 | |
|     inplace: bool,
 | |
|     alpha: NumberType = 1,
 | |
| ):
 | |
|     dim = utils.canonicalize_dims(x.ndim, dim)
 | |
|     torch._check(
 | |
|         index.ndim <= 1,
 | |
|         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
 | |
|     )
 | |
|     index_size = index.size(0) if index.ndim == 1 else 1
 | |
|     tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1
 | |
|     torch._check(
 | |
|         tensor_size == index_size,
 | |
|         lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}",
 | |
|     )
 | |
|     if alpha != 1:
 | |
|         python_type = utils.dtype_to_type(x.dtype)
 | |
|         torch._check(
 | |
|             python_type == bool
 | |
|             or utils.is_weakly_lesser_type(type(alpha), python_type),
 | |
|             lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
 | |
|         )
 | |
|         tensor = tensor * alpha
 | |
|     # Treat scalars as elements of \R^1
 | |
|     zero_dim = x.ndim == 0
 | |
|     x1 = x.unsqueeze(0) if zero_dim else x
 | |
|     idx = (None,) * dim + (index,)
 | |
|     index_put = aten.index_put_ if inplace else aten.index_put
 | |
|     out = index_put(x1, idx, tensor, accumulate=True)
 | |
|     if inplace:
 | |
|         return x
 | |
|     else:
 | |
|         return out.squeeze(0) if zero_dim else out.contiguous()
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.pad_sequence.default)
 | |
| @aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| def pad_sequence(sequences, batch_first=False, padding_value=0.0):
 | |
|     torch._check(len(sequences) > 0, lambda: "received an empty list of sequences")
 | |
|     sequences_size = len(sequences)
 | |
|     max_size = sequences[0].size()
 | |
|     trailing_dims = max_size[1:]
 | |
|     max_len = max(x.size(0) for x in sequences)
 | |
|     if batch_first:
 | |
|         out_dims = (sequences_size, max_len)
 | |
|     else:
 | |
|         out_dims = (max_len, sequences_size)
 | |
|     out_dims = out_dims + trailing_dims
 | |
|     out = sequences[0].new_full(out_dims, padding_value)
 | |
|     dim_paddings = (0, 0) * len(trailing_dims)
 | |
|     for i in range(sequences_size):
 | |
|         currseq = sequences[i]
 | |
|         row = aten.constant_pad_nd(
 | |
|             currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value
 | |
|         )
 | |
|         if batch_first:
 | |
|             out = aten.select_scatter(out, row, dim=0, index=i)
 | |
|         else:
 | |
|             out = aten.select_scatter(out, row, dim=1, index=i)
 | |
|     return out
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.index_copy_)
 | |
| def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
 | |
|     return _index_copy(x, dim, index, tensor, inplace=True)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.index_copy)
 | |
| @out_wrapper()
 | |
| def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
 | |
|     return _index_copy(x, dim, index, tensor, inplace=False)
 | |
| 
 | |
| 
 | |
| def _index_copy(
 | |
|     x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
 | |
| ):
 | |
|     dim = utils.canonicalize_dims(x.ndim, dim)
 | |
|     torch._check(
 | |
|         index.ndim <= 1,
 | |
|         lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
 | |
|     )
 | |
|     # Treat scalars as elements of \R^1
 | |
|     zero_dim = x.ndim == 0
 | |
|     x1 = x.unsqueeze(0) if zero_dim else x
 | |
|     index = index.unsqueeze(0) if index.ndim == 0 else index
 | |
|     idx = (None,) * dim + (index,)
 | |
|     index_put = aten.index_put_ if inplace else aten.index_put
 | |
|     out = index_put(x1, idx, tensor)
 | |
|     if inplace:
 | |
|         return x
 | |
|     else:
 | |
|         return out.squeeze(0) if zero_dim else out.contiguous()
 | |
| 
 | |
| 
 | |
| # nb: Should use acc_t, not op_math
 | |
| @register_decomposition(aten.log_sigmoid_forward)
 | |
| @out_wrapper("output", "buffer")
 | |
| @pw_cast_for_opmath
 | |
| def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
 | |
|     min = torch.minimum(self.new_zeros(()), self)
 | |
|     z = torch.exp(-torch.abs(self))
 | |
|     if self.is_cuda or self.is_xpu:
 | |
|         buffer = self.new_zeros((0,))
 | |
|     else:
 | |
|         buffer = z
 | |
|     return min - torch.log1p(z), buffer
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.uniform)
 | |
| @out_wrapper()
 | |
| def uniform(
 | |
|     x: Tensor,
 | |
|     low: Union[bool, int, float] = 0.0,
 | |
|     high: Union[bool, int, float] = 1.0,
 | |
|     generator: Optional[torch.Generator] = None,
 | |
| ):
 | |
|     return prims._uniform_helper(
 | |
|         x.shape,
 | |
|         low=sym_float(low),
 | |
|         high=sym_float(high),
 | |
|         dtype=x.dtype,
 | |
|         device=x.device,
 | |
|         generator=generator,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.uniform_)
 | |
| def uniform_(self, low=0, high=1, generator=None):
 | |
|     return self.copy_(uniform(self, low, high, generator))
 | |
| 
 | |
| 
 | |
| # aten/src/ATen/native/UpSample.cpp compute_output_size
 | |
| def upsample_compute_output_size(input_size, output_size, scale_factors):
 | |
|     spatial_dimensions = len(input_size) - 2
 | |
|     if output_size is not None:
 | |
|         torch._check(
 | |
|             scale_factors is None,
 | |
|             lambda: "Must specify exactly one of output_size and scale_factors",
 | |
|         )
 | |
|         torch._check(len(output_size) == spatial_dimensions, lambda: "")
 | |
|         return output_size
 | |
|     if scale_factors is not None:
 | |
|         # NB: this isn't necessary lol
 | |
|         torch._check(
 | |
|             output_size is None,
 | |
|             lambda: "Must specify exactly one of output_size and scale_factors",
 | |
|         )
 | |
|         torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
 | |
|         output_size = []
 | |
|         for i, s in enumerate(scale_factors):
 | |
|             if int(s) == s:
 | |
|                 output_size.append(input_size[i + 2] * int(s))
 | |
|             else:
 | |
|                 output_size.append(sym_int(input_size[i + 2] * s))
 | |
|         return output_size
 | |
|     torch._check(
 | |
|         False, lambda: "Must specify exactly one of output_size and scale_factors"
 | |
|     )
 | |
| 
 | |
| 
 | |
| def get_scale_value(scales, idx):
 | |
|     if scales is None:
 | |
|         return None
 | |
|     return scales[idx]
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.upsample_nearest1d.vec)
 | |
| @register_decomposition(aten.upsample_nearest2d.vec)
 | |
| @register_decomposition(aten.upsample_nearest3d.vec)
 | |
| @aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
 | |
| def _upsample_nearest_vec(
 | |
|     input: Tensor,
 | |
|     output_size: Optional[list[int]],
 | |
|     scale_factors: Optional[list[float]],
 | |
| ) -> Tensor:
 | |
|     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
 | |
|     scales = (
 | |
|         scale_factors if scale_factors else [None] * len(osize)  # type: ignore[list-item]
 | |
|     )
 | |
|     return _upsample_nearest(input, osize, scales)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._upsample_nearest_exact1d.vec)
 | |
| @register_decomposition(aten._upsample_nearest_exact2d.vec)
 | |
| @register_decomposition(aten._upsample_nearest_exact3d.vec)
 | |
| @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
 | |
| def _upsample_nearest_exact_vec(
 | |
|     input: Tensor,
 | |
|     output_size: Optional[list[int]],
 | |
|     scale_factors: Optional[list[float]],
 | |
| ) -> Tensor:
 | |
|     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
 | |
|     scales = (
 | |
|         scale_factors if scale_factors else [None] * len(osize)  # type: ignore[list-item]
 | |
|     )
 | |
|     return _upsample_nearest(input, osize, scales, exact=True)
 | |
| 
 | |
| 
 | |
| def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
 | |
|     # For each dim in output_size, compute the set of input indices used
 | |
|     # to produce the upsampled output.
 | |
|     indices = []
 | |
|     num_spatial_dims = len(output_size)
 | |
|     offset = 0.5 if exact else 0.0
 | |
| 
 | |
|     for d in range(num_spatial_dims):
 | |
|         # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
 | |
|         #
 | |
|         # Indices are computed as following:
 | |
|         # scale = isize / osize
 | |
|         # Case: exact=False
 | |
|         # input_index = floor(output_index * scale)
 | |
|         # Same as OpenCV INTER_NEAREST
 | |
|         #
 | |
|         # Case: exact=False
 | |
|         # index_f32 = (output_index + 0.5) * scale - 0.5
 | |
|         # input_index = round(index_f32)
 | |
|         # Same as Pillow and Scikit-Image/Scipy ndi.zoom
 | |
|         osize = output_size[d]
 | |
|         isize = input.shape[-num_spatial_dims + d]
 | |
|         scale = isize / (isize * scales[d]) if scales[d] is not None else isize / osize
 | |
| 
 | |
|         output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
 | |
|         input_indices = ((output_indices + offset) * scale).to(torch.int64)
 | |
|         for _ in range(num_spatial_dims - 1 - d):
 | |
|             input_indices = input_indices.unsqueeze(-1)
 | |
|         indices.append(input_indices)
 | |
|     return indices
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
 | |
| @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def upsample_nearest1d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(input, output_size, [scales])
 | |
| 
 | |
| 
 | |
| @register_decomposition(
 | |
|     [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
 | |
| )
 | |
| @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def upsample_nearest_exact1d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(input, output_size, [scales], exact=True)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
 | |
| @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def upsample_nearest2d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(input, output_size, [scales_h, scales_w])
 | |
| 
 | |
| 
 | |
| @register_decomposition(
 | |
|     [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
 | |
| )
 | |
| @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def _upsample_nearest_exact2d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
 | |
| @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def upsample_nearest3d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales_d: Optional[float] = None,
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
 | |
| 
 | |
| 
 | |
| @register_decomposition(
 | |
|     [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
 | |
| )
 | |
| @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper(preserve_memory_format=True, exact_dtype=True)
 | |
| def _upsample_nearest_exact3d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales_d: Optional[float] = None,
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_nearest(
 | |
|         input, output_size, [scales_d, scales_h, scales_w], exact=True
 | |
|     )
 | |
| 
 | |
| 
 | |
| @pw_cast_for_opmath
 | |
| def _upsample_nearest(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     scales: list[Optional[float]],
 | |
|     exact: bool = False,
 | |
| ) -> Tensor:
 | |
|     spatial_indices = _compute_upsample_nearest_indices(
 | |
|         input, output_size, scales, exact=exact
 | |
|     )
 | |
| 
 | |
|     indices = [None, None] + spatial_indices
 | |
|     result = aten._unsafe_index(input, indices)
 | |
| 
 | |
|     if result.ndim == 4:
 | |
|         # convert output to correct memory format, if necessary
 | |
|         memory_format = utils.suggest_memory_format(input)
 | |
| 
 | |
|         # following "heuristic: only use channels_last path when it's faster than the contiguous path"
 | |
|         n_channels = input.shape[1]
 | |
|         if input.device.type == "cuda" and n_channels < 4:
 | |
|             memory_format = torch.contiguous_format
 | |
| 
 | |
|         result = result.contiguous(memory_format=memory_format)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| def gather_params(params, has_biases, has_projections):
 | |
|     if has_biases and has_projections:
 | |
|         group_size = 5
 | |
|     elif has_biases:
 | |
|         group_size = 4
 | |
|     elif has_projections:
 | |
|         group_size = 3
 | |
|     else:
 | |
|         group_size = 2
 | |
| 
 | |
|     assert len(params) % group_size == 0, len(params)
 | |
|     return [
 | |
|         tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
 | |
|     ]
 | |
| 
 | |
| 
 | |
| def params_hiddens(params, hiddens, i, bidirectional):
 | |
|     if bidirectional:
 | |
|         cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
 | |
|         bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
 | |
|     else:
 | |
|         cur_params, cur_hidden = params[i], hiddens[i]
 | |
|         bidir_params, bidir_hidden = None, None
 | |
| 
 | |
|     return cur_params, cur_hidden, bidir_params, bidir_hidden
 | |
| 
 | |
| 
 | |
| def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
 | |
|     assert last_batch_size > batch_size
 | |
|     hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
 | |
|     return cur_hidden.narrow(0, 0, batch_size)
 | |
| 
 | |
| 
 | |
| def update_hidden_for_packed_reverse(
 | |
|     cur_hidden, last_batch_size, batch_size, inp_hidden
 | |
| ):
 | |
|     if last_batch_size == batch_size:
 | |
|         return cur_hidden
 | |
|     assert last_batch_size < batch_size
 | |
|     return torch.concat(
 | |
|         (
 | |
|             cur_hidden,
 | |
|             inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
 | |
|         )
 | |
|     )
 | |
| 
 | |
| 
 | |
| def one_layer_rnn_data(
 | |
|     inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
 | |
| ):
 | |
|     ih_weight = params[0]
 | |
|     hh_weight = params[1]
 | |
|     ih_bias = params[2] if has_biases else None
 | |
|     hh_bias = params[3] if has_biases else None
 | |
| 
 | |
|     step_output = []
 | |
|     hiddens: list[torch.Tensor] = []
 | |
| 
 | |
|     last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
 | |
|     cur_hidden = hidden.narrow(0, 0, last_batch_size)
 | |
|     split_inp = torch.split(inp, list(batch_sizes))
 | |
|     if reverse:
 | |
|         split_inp = split_inp[::-1]
 | |
|     for inp in split_inp:
 | |
|         i = inp.shape[0]
 | |
| 
 | |
|         if last_batch_size == i:
 | |
|             pass  # don't update cur_hidden
 | |
|         # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
 | |
|         elif reverse:
 | |
|             cur_hidden = update_hidden_for_packed_reverse(
 | |
|                 cur_hidden, last_batch_size, i, hidden
 | |
|             )
 | |
|         else:
 | |
|             cur_hidden = update_hidden_for_packed(
 | |
|                 cur_hidden, last_batch_size, i, hiddens
 | |
|             )
 | |
| 
 | |
|         cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
 | |
|         last_batch_size = i
 | |
|         step_output.append(cur_hidden)
 | |
| 
 | |
|     if reverse:
 | |
|         step_output.reverse()
 | |
|     else:
 | |
|         hiddens.append(cur_hidden)
 | |
|         hiddens.reverse()
 | |
| 
 | |
|     out = torch.cat(step_output, 0)
 | |
|     hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
 | |
|     return out, hidden_out
 | |
| 
 | |
| 
 | |
| def rnn_cell(nonlinearity):
 | |
|     def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
 | |
|         return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
 | |
| 
 | |
|     return inner
 | |
| 
 | |
| 
 | |
| def rnn_cell_data(nonlinearity):
 | |
|     def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
 | |
|         i = F.linear(i, ih_weight, ih_bias)
 | |
|         return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
 | |
| 
 | |
|     return inner
 | |
| 
 | |
| 
 | |
| def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
 | |
|     ih_weight = params[0]
 | |
|     hh_weight = params[1]
 | |
|     ih_bias = params[2] if has_biases else None
 | |
|     hh_bias = params[3] if has_biases else None
 | |
| 
 | |
|     precomputed_input = F.linear(inp, ih_weight, ih_bias)
 | |
|     precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
 | |
|     cur_hidden = hidden.unsqueeze(0)
 | |
|     step_output = []
 | |
|     for i in precomputed_input:
 | |
|         cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
 | |
|         step_output.append(cur_hidden)
 | |
| 
 | |
|     if reverse:
 | |
|         step_output.reverse()
 | |
| 
 | |
|     out = torch.cat(step_output, 0)
 | |
| 
 | |
|     return out, cur_hidden.squeeze(0)
 | |
| 
 | |
| 
 | |
| def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
 | |
|     w0 = params[0]
 | |
|     w1 = params[1]
 | |
|     if has_biases:
 | |
|         w2 = params[2]
 | |
|         w3 = params[3]
 | |
|     else:
 | |
|         w2 = torch.zeros(w0.size())
 | |
|         w3 = torch.zeros(w1.size())
 | |
| 
 | |
|     hx = hidden[0].unsqueeze(0)
 | |
|     cx = hidden[1].unsqueeze(0)
 | |
| 
 | |
|     batch_sizes: list[int] = []
 | |
|     mode = 2  # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
 | |
|     hidden_size = hx.size(2)
 | |
|     num_layers = 1
 | |
| 
 | |
|     # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here
 | |
|     bidirectional = False
 | |
|     batch_first = False
 | |
| 
 | |
|     train = False
 | |
|     # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here.
 | |
|     # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous();
 | |
|     inp = inp.contiguous()
 | |
|     hx = hx.contiguous()
 | |
|     cx = cx.contiguous()
 | |
|     outputs = torch.ops.aten.mkldnn_rnn_layer.default(
 | |
|         inp,
 | |
|         w0,
 | |
|         w1,
 | |
|         w2,
 | |
|         w3,
 | |
|         hx,
 | |
|         cx,
 | |
|         reverse,
 | |
|         batch_sizes,
 | |
|         mode,
 | |
|         hidden_size,
 | |
|         num_layers,
 | |
|         has_biases,
 | |
|         bidirectional,
 | |
|         batch_first,
 | |
|         train,
 | |
|     )
 | |
|     y, hy, cy = outputs[0], outputs[1], outputs[2]
 | |
|     return y, (hy.squeeze(0), cy.squeeze(0))
 | |
| 
 | |
| 
 | |
| def _rnn_helper(
 | |
|     input,
 | |
|     hidden,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
|     batch_first,
 | |
|     layer_fn,
 | |
| ):
 | |
|     input = input.transpose(0, 1) if batch_first else input
 | |
|     final_hiddens = []
 | |
| 
 | |
|     for i in range(num_layers):
 | |
|         cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
 | |
|             params, hidden, i, bidirectional
 | |
|         )
 | |
|         dropout = dropout if (train and num_layers < i - 1) else 0.0
 | |
|         fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
 | |
|         final_hiddens.append(fwd_hidden)
 | |
| 
 | |
|         if bidirectional:
 | |
|             bwd_inp, bwd_hidden = layer_fn(
 | |
|                 input, bidir_hidden, bidir_params, has_biases, reverse=True
 | |
|             )
 | |
|             final_hiddens.append(bwd_hidden)
 | |
| 
 | |
|         if bidirectional:
 | |
|             input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1)  # type: ignore[possibly-undefined]
 | |
|         else:
 | |
|             input = fwd_inp
 | |
| 
 | |
|         if dropout != 0 and train and i < num_layers - 1:
 | |
|             input = torch.dropout(input, dropout, train=True)
 | |
| 
 | |
|     input = input.transpose(0, 1) if batch_first else input
 | |
|     return input, final_hiddens
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.rnn_tanh.input)
 | |
| @aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
 | |
| def rnn_tanh_input(
 | |
|     input,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
|     batch_first,
 | |
| ):
 | |
|     hidden = hx.unbind(0)
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         input,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         batch_first,
 | |
|         partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.rnn_relu.input)
 | |
| @aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
 | |
| def rnn_relu_input(
 | |
|     input,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
|     batch_first,
 | |
| ):
 | |
|     hidden = hx.unbind(0)
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         input,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         batch_first,
 | |
|         partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.rnn_relu.data)
 | |
| @aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
 | |
| def rnn_relu_data(
 | |
|     data,
 | |
|     batch_sizes,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
| ):
 | |
|     hidden = hx.unbind(0)
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         data,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         False,
 | |
|         partial(
 | |
|             one_layer_rnn_data,
 | |
|             batch_sizes=batch_sizes,
 | |
|             hidden_fn=rnn_cell_data(torch.relu),
 | |
|         ),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.rnn_tanh.data)
 | |
| @aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
 | |
| def rnn_tanh_data(
 | |
|     data,
 | |
|     batch_sizes,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
| ):
 | |
|     hidden = hx.unbind(0)
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         data,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         False,
 | |
|         partial(
 | |
|             one_layer_rnn_data,
 | |
|             batch_sizes=batch_sizes,
 | |
|             hidden_fn=rnn_cell_data(torch.tanh),
 | |
|         ),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
 | |
|     gates = F.linear(hx, hh_weight, hh_bias) + inp
 | |
|     chunked_gates = gates.chunk(4, chunk_dim)
 | |
|     in_gate = chunked_gates[0].sigmoid()
 | |
|     forget_gate = chunked_gates[1].sigmoid()
 | |
|     cell_gate = chunked_gates[2].tanh()
 | |
|     out_gate = chunked_gates[3].sigmoid()
 | |
|     cy = forget_gate * cx + (in_gate * cell_gate)
 | |
|     hy = out_gate * cy.tanh()
 | |
|     hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
 | |
| 
 | |
|     return hy, cy
 | |
| 
 | |
| 
 | |
| def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
 | |
|     ih_weight = params[0]
 | |
|     hh_weight = params[1]
 | |
|     ih_bias = params[2] if has_biases else None
 | |
|     hh_bias = params[3] if has_biases else None
 | |
|     hr_weight = (
 | |
|         params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
 | |
|     )
 | |
| 
 | |
|     hx = hidden[0].unsqueeze(0)
 | |
|     cx = hidden[1].unsqueeze(0)
 | |
| 
 | |
|     precomputed_input = F.linear(inp, ih_weight, ih_bias)
 | |
|     precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
 | |
|     step_output = []
 | |
|     for inp in precomputed_input:
 | |
|         hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
 | |
|         step_output.append(hx)
 | |
| 
 | |
|     if reverse:
 | |
|         step_output.reverse()
 | |
| 
 | |
|     out = torch.cat(step_output, 0)
 | |
| 
 | |
|     return out, (hx.squeeze(1), cx.squeeze(1))
 | |
| 
 | |
| 
 | |
| def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
 | |
|     ih_weight = params[0]
 | |
|     hh_weight = params[1]
 | |
|     ih_bias = params[2] if has_biases else None
 | |
|     hh_bias = params[3] if has_biases else None
 | |
|     hr_weight = (
 | |
|         params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
 | |
|     )
 | |
| 
 | |
|     step_output = []
 | |
|     hiddens = []
 | |
| 
 | |
|     last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
 | |
|     split_inp = torch.split(inp, list(batch_sizes))
 | |
|     if reverse:
 | |
|         split_inp = split_inp[::-1]
 | |
| 
 | |
|     orig_hx = hidden[0]
 | |
|     orig_cx = hidden[1]
 | |
|     hx, cx = (
 | |
|         orig_hx.narrow(0, 0, last_batch_size),
 | |
|         orig_cx.narrow(0, 0, last_batch_size),
 | |
|     )
 | |
| 
 | |
|     for inp in split_inp:
 | |
|         i = inp.shape[0]
 | |
|         inp = F.linear(inp, ih_weight, ih_bias)
 | |
| 
 | |
|         # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
 | |
|         if i < last_batch_size:
 | |
|             hiddens.append(
 | |
|                 (
 | |
|                     hx.narrow(0, i, last_batch_size - i),
 | |
|                     cx.narrow(0, i, last_batch_size - i),
 | |
|                 )
 | |
|             )
 | |
|             hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
 | |
| 
 | |
|         # this will only happen when reverse=True
 | |
|         if i > last_batch_size:
 | |
|             hx = torch.concat(
 | |
|                 (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
 | |
|             )
 | |
|             cx = torch.concat(
 | |
|                 (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
 | |
|             )
 | |
| 
 | |
|         hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
 | |
|         last_batch_size = i
 | |
|         step_output.append(hx)
 | |
| 
 | |
|     if reverse:
 | |
|         step_output.reverse()
 | |
|         hidden_out = (hx, cx)
 | |
|     else:
 | |
|         hiddens.append((hx, cx))
 | |
|         hiddens.reverse()
 | |
|         hidden0, hidden1 = zip(*hiddens)
 | |
|         hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
 | |
| 
 | |
|     out = torch.cat(step_output, 0)
 | |
|     return out, hidden_out
 | |
| 
 | |
| 
 | |
| def select_one_layer_lstm_function(input, hx, params):
 | |
|     r"""Check whether we could use decompose lstm with mkldnn_rnn_layer.
 | |
|     All the below conditions need to be met:
 | |
|         * ``torch._C._get_mkldnn_enabled()`` returns ``True``.
 | |
|         * All the input args are on CPU.
 | |
|         * The dtypes of args are either torch.float or torch.bfloat16.
 | |
|         * Inference.
 | |
|         * ``has_projections`` returns ``False``.
 | |
| 
 | |
|     Args:
 | |
|         * input: the input sequence to LSTM
 | |
|         * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM
 | |
|         * params: the weight and bias tensors of LSTM
 | |
|     """
 | |
| 
 | |
|     def use_mkldnn(input, hx, params):
 | |
|         if not torch._C._get_mkldnn_enabled():
 | |
|             return False
 | |
| 
 | |
|         tensors = [input] + list(hx) + list(chain.from_iterable(params))
 | |
|         devices = {t.device for t in tensors}
 | |
|         if len(devices) != 1:
 | |
|             return False
 | |
| 
 | |
|         device = devices.pop()
 | |
|         if device != torch.device("cpu"):
 | |
|             return False
 | |
|         # With autocast, possible to have mixed dtype here
 | |
|         dtypes = {t.dtype for t in tensors}
 | |
|         for dtype in dtypes:
 | |
|             if dtype not in [torch.float, torch.bfloat16]:
 | |
|                 return False
 | |
| 
 | |
|         if input.requires_grad:
 | |
|             return False
 | |
| 
 | |
|         has_projections = hx[0].size(2) != hx[1].size(2)
 | |
|         if has_projections:
 | |
|             return False
 | |
| 
 | |
|         return True
 | |
| 
 | |
|     # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm
 | |
|     # will expand over the seq_len dim
 | |
|     if use_mkldnn(input, hx, params):
 | |
|         return mkldnn_one_layer_lstm
 | |
|     else:
 | |
|         return one_layer_lstm
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.lstm.input)
 | |
| @aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.lstm.input.py_impl(DispatchKey.Autograd)
 | |
| def lstm_impl(
 | |
|     input,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
|     batch_first,
 | |
| ):
 | |
|     assert len(hx) == 2, "lstm expects two hidden states"
 | |
|     params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
 | |
|     hidden = list(zip(hx[0], hx[1]))
 | |
|     layer_fn = select_one_layer_lstm_function(input, hx, params)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         input,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         batch_first,
 | |
|         layer_fn,
 | |
|     )
 | |
|     final_hiddens = list(zip(*final_hiddens))
 | |
|     return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.lstm.data)
 | |
| @aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.lstm.data.py_impl(DispatchKey.Autograd)
 | |
| def lstm_data_impl(
 | |
|     data,
 | |
|     batch_sizes,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
| ):
 | |
|     assert len(hx) == 2, "lstm expects two hidden states"
 | |
|     params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
 | |
|     hidden = list(zip(hx[0], hx[1]))
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         data,
 | |
|         hidden,
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         False,
 | |
|         partial(one_layer_lstm_data, batch_sizes=batch_sizes),
 | |
|     )
 | |
|     final_hiddens = list(zip(*final_hiddens))
 | |
|     return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
 | |
| 
 | |
| 
 | |
| def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
 | |
|     chunked_igates = inp.chunk(3, 1)
 | |
|     chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
 | |
|     reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
 | |
|     input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
 | |
|     new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
 | |
|     return (cur_hidden - new_gate) * input_gate + new_gate
 | |
| 
 | |
| 
 | |
| def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
 | |
|     chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
 | |
|     chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
 | |
|     reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
 | |
|     input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
 | |
|     new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
 | |
|     return (cur_hidden - new_gate) * input_gate + new_gate
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.gru.data)
 | |
| @aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.gru.data.py_impl(DispatchKey.Autograd)
 | |
| def gru_impl_data(
 | |
|     data,
 | |
|     batch_sizes,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
| ):
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         data,
 | |
|         hx.unbind(0),
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         False,
 | |
|         partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.gru.input)
 | |
| @aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.gru.input.py_impl(DispatchKey.Autograd)
 | |
| def gru_impl(
 | |
|     input,
 | |
|     hx,
 | |
|     params,
 | |
|     has_biases,
 | |
|     num_layers,
 | |
|     dropout,
 | |
|     train,
 | |
|     bidirectional,
 | |
|     batch_first,
 | |
| ):
 | |
|     params = gather_params(params, has_biases, False)
 | |
|     out, final_hiddens = _rnn_helper(
 | |
|         input,
 | |
|         hx.unbind(0),
 | |
|         params,
 | |
|         has_biases,
 | |
|         num_layers,
 | |
|         dropout,
 | |
|         train,
 | |
|         bidirectional,
 | |
|         batch_first,
 | |
|         partial(one_layer_rnn, hidden_fn=gru_cell),
 | |
|     )
 | |
|     return out, torch.stack(final_hiddens, 0)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._upsample_bilinear2d_aa.vec)
 | |
| @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
 | |
| def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
 | |
|     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
 | |
|     scale_h = get_scale_value(scale_factors, 0)
 | |
|     scale_w = get_scale_value(scale_factors, 1)
 | |
|     return torch.ops.aten._upsample_bilinear2d_aa(
 | |
|         input, osize, align_corners, scale_h, scale_w
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten._upsample_bicubic2d_aa.vec)
 | |
| @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd)
 | |
| def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors):
 | |
|     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
 | |
|     scale_h = get_scale_value(scale_factors, 0)
 | |
|     scale_w = get_scale_value(scale_factors, 1)
 | |
|     return torch.ops.aten._upsample_bicubic2d_aa(
 | |
|         input, osize, align_corners, scale_h, scale_w
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.upsample_bilinear2d.vec)
 | |
| @register_decomposition(aten.upsample_trilinear3d.vec)
 | |
| @aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
 | |
| @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd)
 | |
| def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
 | |
|     osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
 | |
|     scales = scale_factors if scale_factors else [None] * len(osize)
 | |
|     return _upsample_linear(input, osize, align_corners, scales)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out])
 | |
| @out_wrapper()
 | |
| def upsample_linear1d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     align_corners: bool,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_linear(input, output_size, align_corners, [scales_w])
 | |
| 
 | |
| 
 | |
| @register_decomposition(
 | |
|     [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
 | |
| )
 | |
| @aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper()
 | |
| def upsample_bilinear2d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     align_corners: bool,
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w])
 | |
| 
 | |
| 
 | |
| @register_decomposition(
 | |
|     [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out]
 | |
| )
 | |
| @out_wrapper()
 | |
| def upsample_trilinear3d(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     align_corners: bool,
 | |
|     scales_d: Optional[float] = None,
 | |
|     scales_h: Optional[float] = None,
 | |
|     scales_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     return _upsample_linear(
 | |
|         input, output_size, align_corners, [scales_d, scales_h, scales_w]
 | |
|     )
 | |
| 
 | |
| 
 | |
| def _compute_scale(in_size, out_size, align_corners, scale=None):
 | |
|     if align_corners:
 | |
|         return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0
 | |
|     else:
 | |
|         return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size
 | |
| 
 | |
| 
 | |
| def _compute_source_index(scale, dst_index, align_corners):
 | |
|     if align_corners:
 | |
|         return scale * dst_index
 | |
|     else:
 | |
|         return scale * (dst_index + 0.5) - 0.5
 | |
| 
 | |
| 
 | |
| def _sum_tensors_uint8(
 | |
|     src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor
 | |
| ) -> Tensor:
 | |
|     output = _sum_tensors(
 | |
|         s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights)
 | |
|     ) + (1 << (weights_precision - 1))
 | |
|     output = output >> weights_precision
 | |
|     return torch.clamp(output, 0, 255).to(torch.uint8)
 | |
| 
 | |
| 
 | |
| def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
 | |
|     max_weight = torch.stack(weights).max()
 | |
|     max_weight_precision = 22
 | |
|     precisions = torch.arange(max_weight_precision, device=max_weight.device)
 | |
|     values = 0.5 + max_weight * (1 << (precisions + 1))
 | |
|     mask = values >= (1 << 15)
 | |
|     return max_weight_precision - mask.sum()
 | |
| 
 | |
| 
 | |
| @pw_cast_for_opmath
 | |
| def _upsample_linear(
 | |
|     input: Tensor,
 | |
|     output_size: list[int],
 | |
|     align_corners: bool,
 | |
|     scales: list[Optional[float]],
 | |
| ) -> Tensor:
 | |
|     # get dimensions of original image
 | |
|     n_channels = input.shape[1]
 | |
|     inp_sizes = input.shape[2:]
 | |
|     n_dims = len(inp_sizes)
 | |
| 
 | |
|     _, dtype = utils.elementwise_dtypes(
 | |
|         input,
 | |
|         type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 | |
|     )
 | |
| 
 | |
|     def get_values(inp_size, out_size, scales, nsqueeze):
 | |
|         # First Calculate scaling factor
 | |
|         scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
 | |
|         # We have to create arange with int64 dtype and use .to in order to avoid
 | |
|         # additional kernels creation in inductor and get a perf slowdown
 | |
|         i = torch.arange(out_size, device=input.device).to(dtype=dtype)
 | |
| 
 | |
|         x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
 | |
|         x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
 | |
|         x = x_f32.to(torch.int64)
 | |
|         xp1 = (x + 1).clamp(max=inp_size - 1)
 | |
|         return x_f32, x, xp1
 | |
| 
 | |
|     values = [
 | |
|         get_values(inp_size, out_size, scales, n_dims - 1 - i)
 | |
|         for i, (inp_size, out_size, scales) in enumerate(
 | |
|             zip(inp_sizes, output_size, scales)
 | |
|         )
 | |
|     ]
 | |
|     xs_f32, xs, xp1s = list(zip(*values))
 | |
| 
 | |
|     vs = []
 | |
|     for a in product(*[[0, 1]] * n_dims):
 | |
|         idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)]
 | |
|         v = aten._unsafe_index(input, idx)
 | |
|         v = _maybe_convert_to_dtype(v, dtype)
 | |
|         vs.append(v)
 | |
| 
 | |
|     for i in reversed(range(n_dims)):
 | |
|         xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
 | |
|         vs = [
 | |
|             # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha
 | |
|             v1 + torch.mul(v2 - v1, xscale)
 | |
|             for v1, v2 in zip(vs[::2], vs[1::2])
 | |
|         ]
 | |
| 
 | |
|     assert len(vs) == 1
 | |
|     result = vs[0]
 | |
| 
 | |
|     # convert output to correct memory format, if necessary
 | |
|     memory_format = utils.suggest_memory_format(input)
 | |
| 
 | |
|     # following "heuristic: only use channels_last path when it's faster than the contiguous path"
 | |
|     if input.device.type == "cuda" and n_channels < 16:
 | |
|         memory_format = torch.contiguous_format
 | |
| 
 | |
|     assert isinstance(result, torch.Tensor)
 | |
| 
 | |
|     result = result.contiguous(memory_format=memory_format)
 | |
| 
 | |
|     if not input.is_floating_point():
 | |
|         result = result.round()
 | |
| 
 | |
|     return result
 | |
| 
 | |
| 
 | |
| # We should be applying decompositions after all transformations
 | |
| @register_decomposition(aten.is_same_size.default)
 | |
| def is_same_size(a: Tensor, b: Tensor) -> bool:
 | |
|     return a.shape == b.shape
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._reshape_alias, aten._unsafe_view])
 | |
| @out_wrapper()
 | |
| def _reshape_alias(x, shape, *args):
 | |
|     return aten.view(x, shape)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._unsafe_index])
 | |
| def _unsafe_index(x, indices):
 | |
|     return aten.index(x, indices)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._unsafe_index_put])
 | |
| def _unsafe_index_put(x, indices, value, accumulate=False):
 | |
|     return aten.index_put(x, indices, value, accumulate)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._unsafe_masked_index])
 | |
| def _unsafe_masked_index(x, mask, indices, fill):
 | |
|     for index in indices:
 | |
|         if index is not None:
 | |
|             torch._check(
 | |
|                 index.dtype in [torch.long, torch.int],
 | |
|                 lambda: "tensors used as indices must be long or int tensors",
 | |
|             )
 | |
| 
 | |
|     torch._check(
 | |
|         mask.dtype == torch.bool,
 | |
|         lambda: "tensors used as masks must be bool tensors",
 | |
|     )
 | |
| 
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     if guard_or_false(x.numel() == 0):
 | |
|         meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
 | |
|         return x.new_full(meta_result.shape, fill)
 | |
| 
 | |
|     for i in range(len(indices)):
 | |
|         index = indices[i]
 | |
|         if index is not None:
 | |
|             indices[i] = index.clamp(min=0, max=x.size(i) - 1)
 | |
| 
 | |
|     return aten._unsafe_index(x, indices).masked_fill(~mask, fill)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten._unsafe_masked_index_put_accumulate])
 | |
| def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
 | |
|     for index in indices:
 | |
|         if index is not None:
 | |
|             torch._check(
 | |
|                 index.dtype in [torch.long, torch.int],
 | |
|                 lambda: "tensors used as indices must be long or int tensors",
 | |
|             )
 | |
| 
 | |
|     torch._check(
 | |
|         mask.dtype == torch.bool,
 | |
|         lambda: "tensors used as masks must be bool tensors",
 | |
|     )
 | |
| 
 | |
|     if x.numel() == 0:
 | |
|         return x.clone()
 | |
| 
 | |
|     for i in range(len(indices)):
 | |
|         index = indices[i]
 | |
|         if index is not None:
 | |
|             indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1)
 | |
| 
 | |
|     masked_value = values.masked_fill(~mask, 0)
 | |
|     return aten._unsafe_index_put(x, indices, masked_value, accumulate=True)
 | |
| 
 | |
| 
 | |
| def _nll_loss_forward(
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     # self can be [N, C] or [C]
 | |
|     # target can be [N] or []
 | |
| 
 | |
|     n_dims = self.dim()
 | |
|     channel_dim = 1
 | |
|     if n_dims < 2:
 | |
|         channel_dim = 0
 | |
| 
 | |
|     if weight is not None:
 | |
|         if n_dims > 1:
 | |
|             shape = [
 | |
|                 1,
 | |
|             ] * n_dims
 | |
|             shape[channel_dim] = weight.shape[0]
 | |
|             w = weight.view(shape)
 | |
|         else:
 | |
|             w = weight
 | |
|         self = self * w
 | |
|     safe_target = torch.where(target != ignore_index, target, 0)
 | |
|     safe_target_ = safe_target.unsqueeze(channel_dim)
 | |
|     # target can be [N, 1] or [1]
 | |
| 
 | |
|     result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
 | |
| 
 | |
|     result = torch.where(target != ignore_index, result, 0)
 | |
| 
 | |
|     if reduction == Reduction.NONE.value and n_dims > 1:
 | |
|         total_weight = self.new_full((), 0.0)
 | |
|         return result, total_weight
 | |
| 
 | |
|     if weight is not None:
 | |
|         # pyrefly: ignore  # unbound-name
 | |
|         w = w.expand(self.shape)
 | |
|         wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
 | |
|         wsum = torch.where(target != ignore_index, wsum, 0)
 | |
|         total_weight = wsum.sum()
 | |
|     else:
 | |
|         total_weight = (target != ignore_index).sum().to(self)
 | |
| 
 | |
|     if reduction == Reduction.SUM.value:
 | |
|         result = result.sum()
 | |
|     elif reduction == Reduction.MEAN.value:
 | |
|         result = result.sum() / total_weight
 | |
| 
 | |
|     return result, total_weight
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.nll_loss_forward)
 | |
| @out_wrapper("output", "total_weight")
 | |
| def nll_loss_forward(
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
 | |
|     assert target.dim() <= 1, (
 | |
|         "0D or 1D target tensor expected, multi-target not supported"
 | |
|     )
 | |
| 
 | |
|     no_batch_dim = self.dim() == 1 and target.dim() == 0
 | |
|     assert no_batch_dim or (self.shape[0] == target.shape[0]), (
 | |
|         f"size mismatch (got input: {self.shape}, target: {target.shape})"
 | |
|     )
 | |
| 
 | |
|     n_classes = self.shape[-1]
 | |
| 
 | |
|     assert weight is None or (weight.dim() == 1 and weight.numel() == n_classes), (
 | |
|         f"weight tensor should be defined either for all {n_classes} classes or no classes "
 | |
|         f"but got weight tensor of shape: {weight.shape}"
 | |
|     )
 | |
| 
 | |
|     return _nll_loss_forward(self, target, weight, reduction, ignore_index)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.nll_loss2d_forward)
 | |
| @out_wrapper("output", "total_weight")
 | |
| def nll_loss2d_forward(
 | |
|     self: Tensor,
 | |
|     target: Tensor,
 | |
|     weight: Optional[Tensor],
 | |
|     reduction: int,
 | |
|     ignore_index: int,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     return _nll_loss_forward(self, target, weight, reduction, ignore_index)
 | |
| 
 | |
| 
 | |
| # These are adapted from aten/src/ATen/native/UpSample.h, which is based on
 | |
| # https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
 | |
| def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
 | |
|     return ((A + 2) * x - (A + 3)) * x * x + 1
 | |
| 
 | |
| 
 | |
| def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
 | |
|     return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
 | |
| 
 | |
| 
 | |
| def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
 | |
|     A = -0.75
 | |
| 
 | |
|     if t.device == torch.device("cpu"):
 | |
|         tt1 = torch.stack([t, 1.0 - t], dim=0)
 | |
|         tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0)
 | |
|         w03 = _upsample_cubic_convolution2(tt2, A)
 | |
|         w12 = _upsample_cubic_convolution1(tt1, A)
 | |
|         w0, w3 = torch.unbind(w03, dim=0)
 | |
|         w1, w2 = torch.unbind(w12, dim=0)
 | |
|         return w0, w1, w2, w3
 | |
|     else:
 | |
|         return (
 | |
|             _upsample_cubic_convolution2(t + 1.0, A),
 | |
|             _upsample_cubic_convolution1(t, A),
 | |
|             _upsample_cubic_convolution1(1.0 - t, A),
 | |
|             _upsample_cubic_convolution2(2.0 - t, A),
 | |
|         )
 | |
| 
 | |
| 
 | |
| def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
 | |
|     coeffs2 = _upsample_get_cubic_coefficients(ts)
 | |
|     return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))
 | |
| 
 | |
| 
 | |
| # Need this instead of just sum() to keep mypy happy
 | |
| def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
 | |
|     return reduce(torch.add, ts)
 | |
| 
 | |
| 
 | |
| def _linspace_from_neg_one(
 | |
|     num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device
 | |
| ):
 | |
|     if num_steps <= 1:
 | |
|         return torch.tensor(0, device=device, dtype=dtype)
 | |
| 
 | |
|     a = ((num_steps - 1) / num_steps) if not align_corners else 1
 | |
|     return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype)
 | |
| 
 | |
| 
 | |
| def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool):
 | |
|     dtype = theta.dtype
 | |
|     device = theta.device
 | |
| 
 | |
|     # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated
 | |
|     # corresponding to each individual tensor: grid_x, grid_y, grid_one
 | |
|     grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1)
 | |
|     grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1)
 | |
|     grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device)
 | |
| 
 | |
|     # this is just a temporary hack and we should use torch.stack here once #104480 is merged
 | |
|     grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0)
 | |
|     grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0)
 | |
|     grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0)
 | |
|     return grid_x + grid_y + grid_one
 | |
| 
 | |
| 
 | |
| def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool):
 | |
|     dtype = theta.dtype
 | |
|     device = theta.device
 | |
| 
 | |
|     grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1)
 | |
|     grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1)
 | |
|     grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1)
 | |
|     grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device)
 | |
| 
 | |
|     # this is just a temporary hack and we should use torch.stack here once #104480 is merged
 | |
|     grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0)
 | |
|     grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0)
 | |
|     grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0)
 | |
|     grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0)
 | |
|     return grid_x + grid_y + grid_z + grid_one
 | |
| 
 | |
| 
 | |
| def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool):
 | |
|     n, _, h, w = size
 | |
|     base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
 | |
|     # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
 | |
|     # We do manually a matrix multiplication which is faster than mm()
 | |
|     # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2)
 | |
|     grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2)
 | |
|     return grid.view(n, h, w, 2)
 | |
| 
 | |
| 
 | |
| def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool):
 | |
|     n, _, d, h, w = size
 | |
|     base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
 | |
|     # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
 | |
|     # We do manually a matrix multiplication which is faster than mm()
 | |
|     # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3)
 | |
|     grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2)
 | |
|     return grid.view(n, d, h, w, 3)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.affine_grid_generator)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool):
 | |
|     torch._check(
 | |
|         len(size) in (4, 5),
 | |
|         lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
 | |
|     )
 | |
|     if len(size) == 4:
 | |
|         return _affine_grid_generator_4d(theta, size, align_corners=align_corners)
 | |
|     else:
 | |
|         return _affine_grid_generator_5d(theta, size, align_corners=align_corners)
 | |
| 
 | |
| 
 | |
| def _grid_sampler_2d(
 | |
|     a: Tensor,
 | |
|     grid: Tensor,
 | |
|     interpolation_mode: int = 0,
 | |
|     padding_mode: int = 0,
 | |
|     align_corners: bool = False,
 | |
|     _expand_grid: bool = True,
 | |
| ) -> Tensor:
 | |
|     # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to
 | |
|     # optionally expand the input grid for performance reasons.
 | |
|     # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
 | |
|     # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
 | |
|     # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
 | |
|     # Thus we apply this hack to not expand the grid for this case.
 | |
| 
 | |
|     torch._check(
 | |
|         interpolation_mode in (0, 1, 2),
 | |
|         lambda: f"Invalid interpolation mode {interpolation_mode}",
 | |
|     )
 | |
|     torch._check(
 | |
|         padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
 | |
|     )
 | |
| 
 | |
|     def unnormalize(coords: Tensor, size: int) -> Tensor:
 | |
|         # Rescale coordinates from [-1, 1] to:
 | |
|         #   [0, size - 1] if align_corners is True
 | |
|         #   [-.5, size -.5] if align_corners is False
 | |
|         mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
 | |
|         ofs = size * 0.5 - 0.5
 | |
|         return coords * mul + ofs
 | |
| 
 | |
|     # Reflects coordinates until they fall between low and high (inclusive).
 | |
|     # The bounds are passed as twice their value so that half-integer values
 | |
|     # can be represented as ints.
 | |
|     def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
 | |
|         if twice_low == twice_high:
 | |
|             return torch.zeros_like(coords)
 | |
|         coords_min = twice_low / 2
 | |
|         coords_span = (twice_high - twice_low) / 2
 | |
|         coords2 = (coords - coords_min).abs()
 | |
|         extra = torch.fmod(coords2, coords_span)
 | |
|         flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
 | |
|         return torch.where(
 | |
|             flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
 | |
|         )
 | |
| 
 | |
|     def compute_coordinates(coords: Tensor, size: int) -> Tensor:
 | |
|         if padding_mode == 0:  # Zero
 | |
|             return coords
 | |
|         elif padding_mode == 1:  # Borders
 | |
|             return torch.clamp(coords, 0, size - 1)
 | |
|         else:  # padding_mode == 2, Reflection
 | |
|             if align_corners:
 | |
|                 coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
 | |
|             else:
 | |
|                 coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
 | |
|             return torch.clamp(coords_reflected, 0, size - 1)
 | |
| 
 | |
|     def compute_source_index(coords: Tensor, size: int) -> Tensor:
 | |
|         coords_un = unnormalize(coords, size)
 | |
|         return compute_coordinates(coords_un, size)
 | |
| 
 | |
|     N, C, iH, iW = a.shape
 | |
|     _, oH, oW, two = grid.shape
 | |
|     assert two == 2
 | |
| 
 | |
|     if _expand_grid:
 | |
|         # Let's expand grid to [N, C, oH, oW, 2]
 | |
|         # This allows to generate a single triton cuda kernel instead of two kernels.
 | |
|         # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW
 | |
|         # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW
 | |
|         # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW
 | |
|         grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2)
 | |
| 
 | |
|     def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
 | |
|         return torch.logical_and(
 | |
|             0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
 | |
|         )
 | |
| 
 | |
|     N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
 | |
|     C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)
 | |
| 
 | |
|     def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
 | |
|         cond = in_bounds_cond(xs, ys)
 | |
|         # To clip to inside valid coordinates, we map the coordinates
 | |
|         # to (x, y) = (0, 0) and also set the weight to 0
 | |
|         # We also change the shape of the tensor to the appropriate one for
 | |
|         # broadcasting with N_idx, C_idx for the purposes of advanced indexing
 | |
|         c = C if _expand_grid else 1
 | |
|         return tuple(
 | |
|             torch.where(cond, t, 0).view(N, c, oH, oW)
 | |
|             for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
 | |
|         )
 | |
| 
 | |
|     def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
 | |
|         # Perform clipping, index into input tensor and multiply by weight
 | |
|         idx_x, idx_y, w_ = clip(ix, iy, w)
 | |
|         return a[N_idx, C_idx, idx_y, idx_x] * w_
 | |
| 
 | |
|     x = grid[..., 0]
 | |
|     y = grid[..., 1]
 | |
| 
 | |
|     if interpolation_mode == 0:  # Bilinear
 | |
|         ix = compute_source_index(x, iW)
 | |
|         iy = compute_source_index(y, iH)
 | |
| 
 | |
|         ix_nw, iy_nw = ix.floor(), iy.floor()
 | |
|         ix_ne, iy_ne = ix_nw + 1, iy_nw
 | |
|         ix_sw, iy_sw = ix_nw, iy_nw + 1
 | |
|         ix_se, iy_se = ix_ne, iy_sw
 | |
| 
 | |
|         w_nw = (ix_se - ix) * (iy_se - iy)
 | |
|         w_ne = (ix - ix_sw) * (iy_sw - iy)
 | |
|         w_sw = (ix_ne - ix) * (iy - iy_ne)
 | |
|         w_se = (ix - ix_nw) * (iy - iy_nw)
 | |
| 
 | |
|         return _sum_tensors(
 | |
|             get_summand(ix, iy, w)
 | |
|             for (ix, iy, w) in (
 | |
|                 (ix_nw, iy_nw, w_nw),
 | |
|                 (ix_ne, iy_ne, w_ne),
 | |
|                 (ix_sw, iy_sw, w_sw),
 | |
|                 (ix_se, iy_se, w_se),
 | |
|             )
 | |
|         )
 | |
|     elif interpolation_mode == 1:  # Nearest
 | |
|         ix = compute_source_index(x, iW)
 | |
|         iy = compute_source_index(y, iH)
 | |
| 
 | |
|         ix_nearest = ix.round()
 | |
|         iy_nearest = iy.round()
 | |
| 
 | |
|         return get_summand(ix_nearest, iy_nearest, 1)
 | |
|     else:  # interpolation_mode == 2, Bicubic
 | |
|         ix = unnormalize(x, iW)
 | |
|         iy = unnormalize(y, iH)
 | |
| 
 | |
|         ix_nw = ix.floor()
 | |
|         iy_nw = iy.floor()
 | |
| 
 | |
|         tx = ix - ix_nw
 | |
|         ty = iy - iy_nw
 | |
| 
 | |
|         if not _expand_grid:
 | |
|             tx = tx.unsqueeze(1)
 | |
|             ty = ty.unsqueeze(1)
 | |
| 
 | |
|         def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
 | |
|             x = compute_coordinates(ix, iW)
 | |
|             y = compute_coordinates(iy, iH)
 | |
|             return get_summand(x, y, 1)
 | |
| 
 | |
|         def get_coeff(ofs: int) -> Tensor:
 | |
|             iy_ofs = iy_nw + (ofs - 1)
 | |
|             cs = (
 | |
|                 get_value_bounded(ix_nw - 1, iy_ofs),
 | |
|                 get_value_bounded(ix_nw, iy_ofs),
 | |
|                 get_value_bounded(ix_nw + 1, iy_ofs),
 | |
|                 get_value_bounded(ix_nw + 2, iy_ofs),
 | |
|             )
 | |
|             return _upsample_cubic_interp1d(cs, tx)
 | |
| 
 | |
|         coeffs = tuple(get_coeff(ofs) for ofs in range(4))
 | |
|         return _upsample_cubic_interp1d(coeffs, ty)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.grid_sampler_2d)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def grid_sampler_2d(
 | |
|     a: Tensor,
 | |
|     grid: Tensor,
 | |
|     interpolation_mode: int = 0,
 | |
|     padding_mode: int = 0,
 | |
|     align_corners: bool = False,
 | |
| ) -> Tensor:
 | |
|     return _grid_sampler_2d(
 | |
|         a,
 | |
|         grid=grid,
 | |
|         interpolation_mode=interpolation_mode,
 | |
|         padding_mode=padding_mode,
 | |
|         align_corners=align_corners,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.mv)
 | |
| @out_wrapper(exact_dtype=True)
 | |
| @pw_cast_for_opmath
 | |
| def mv(self, vec):
 | |
|     torch._check(
 | |
|         self.dim() == 2 and vec.dim() == 1,
 | |
|         lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
 | |
|     )
 | |
|     torch._check(
 | |
|         self.size(1) == vec.size(0),
 | |
|         lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
 | |
|     )
 | |
|     return (self * vec).sum(dim=1)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.binary_cross_entropy_with_logits)
 | |
| @out_wrapper()
 | |
| def binary_cross_entropy_with_logits(
 | |
|     self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
 | |
| ):
 | |
|     if pos_weight is not None:
 | |
|         log_weight = (pos_weight - 1) * target + 1
 | |
|         loss = (1 - target) * self - (log_weight * F.logsigmoid(self))
 | |
|     else:
 | |
|         loss = (1 - target) * self - F.logsigmoid(self)
 | |
| 
 | |
|     if weight is not None:
 | |
|         loss = loss * weight
 | |
| 
 | |
|     return apply_loss_reduction(loss, reduction)
 | |
| 
 | |
| 
 | |
| def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool:
 | |
|     # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp
 | |
| 
 | |
|     t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
 | |
| 
 | |
|     from torch.fx.experimental.symbolic_shapes import guard_or_false
 | |
| 
 | |
|     if not (t1.ndim >= 3 and t2.ndim <= 2):
 | |
|         return False
 | |
|     if t2.requires_grad and not is_out:
 | |
|         return True
 | |
|     if tensor1.ndim == 2:
 | |
|         return False
 | |
|     if guard_or_false(t1.numel() == 0):
 | |
|         return True
 | |
| 
 | |
|     t1_shape = t1.shape
 | |
|     t1_stride = t1.stride()
 | |
| 
 | |
|     # Check the contiguous, we can skip the dim with size of 1
 | |
|     # as aten: https://github.com/pytorch/pytorch/blob/e201460f8aa1510b4c4686627d57b69756c4b916/aten/src/ATen/TensorGeometry.cpp#L17
 | |
|     expected_stride = [1]
 | |
|     for size in reversed(t1_shape[1:]):
 | |
|         expected_stride.append(size * expected_stride[-1])
 | |
|     return all(
 | |
|         guard_or_false(size == 1) or guard_or_false(left == right)
 | |
|         for left, right, size in zip(
 | |
|             t1_stride, list(reversed(expected_stride)), t1_shape
 | |
|         )
 | |
|     )
 | |
| 
 | |
| 
 | |
| @aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @out_wrapper(pass_is_out=True)
 | |
| def matmul(tensor1, tensor2, *, is_out=False):
 | |
|     dim_tensor1 = tensor1.dim()
 | |
|     dim_tensor2 = tensor2.dim()
 | |
|     assert dim_tensor1 != 0 and dim_tensor2 != 0
 | |
|     if dim_tensor1 == 1 and dim_tensor2 == 1:
 | |
|         return torch.dot(tensor1, tensor2)
 | |
|     elif dim_tensor1 == 2 and dim_tensor2 == 1:
 | |
|         return torch.mv(tensor1, tensor2)
 | |
|     elif dim_tensor1 == 1 and dim_tensor2 == 2:
 | |
|         return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
 | |
|     elif dim_tensor1 == 2 and dim_tensor2 == 2:
 | |
|         return torch.mm(tensor1, tensor2)
 | |
|     elif should_fold(tensor1, tensor2, is_out):
 | |
|         # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
 | |
|         # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
 | |
|         # and some condition on the strides is fulfilled
 | |
| 
 | |
|         # optimization: use mm instead of bmm by folding the batch of the larger tensor
 | |
|         # into its leading matrix dimension
 | |
|         transpose = dim_tensor2 > dim_tensor1
 | |
|         t1 = tensor2.mT if transpose else tensor1
 | |
|         t2 = (
 | |
|             tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
 | |
|         )
 | |
|         # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
 | |
|         #            and t1 and t2 are matmul-compatible
 | |
| 
 | |
|         # Why not t1.view(-1, sizes_1[-1])?
 | |
|         # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
 | |
|         # This can happen in e.g. [3, 5, 0] @ [0, 0].
 | |
|         sizes_1 = t1.shape
 | |
|         output_shape = list(sizes_1[:-1])
 | |
|         folded_dim1 = reduce(operator.mul, output_shape)
 | |
| 
 | |
|         # Readjust output_shape if we are multiplying by a matrix
 | |
|         t2_is_matrix = t2.dim() == 2
 | |
|         if t2_is_matrix:
 | |
|             output_shape.append(t2.shape[1])
 | |
| 
 | |
|         # This will almost always be a view.
 | |
|         # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation
 | |
|         t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
 | |
|         if t2_is_matrix:
 | |
|             # This copies if we perform a 2D @ 3D and the first tensor requires_grad
 | |
|             # See should_fold native/LinearAlgebra.cpp for why.
 | |
|             output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape)
 | |
|             return output.mT.contiguous() if transpose else output
 | |
|         else:
 | |
|             return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
 | |
| 
 | |
|     elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
 | |
|         # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
 | |
|         # we track m1 vs m2 separately even though they must match for nicer error messages
 | |
|         n = tensor1.size(-2) if dim_tensor1 > 1 else 1
 | |
|         m1 = tensor1.size(-1)
 | |
|         batch_tensor1 = tensor1.shape[:-2]
 | |
|         m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
 | |
|         p = tensor2.size(-1) if dim_tensor2 > 1 else 1
 | |
| 
 | |
|         batch_tensor2: list[int] = []
 | |
|         # TODO: handling of slice
 | |
|         for i in range(dim_tensor2 - 2):
 | |
|             batch_tensor2.append(tensor2.size(i))
 | |
| 
 | |
|         # Same optimization for the gradients as that in should_fold
 | |
|         # If we're going to broadcast, we force it to go through the should_fold branch
 | |
|         if (
 | |
|             dim_tensor1 == 3
 | |
|             and dim_tensor2 == 3
 | |
|             and batch_tensor1[0] != batch_tensor2[0]
 | |
|         ):
 | |
|             if batch_tensor1[0] == 1 and tensor1.requires_grad:
 | |
|                 return matmul(tensor1.squeeze(0), tensor2)
 | |
|             if batch_tensor2[0] == 1 and tensor2.requires_grad:
 | |
|                 return matmul(tensor1, tensor2.squeeze(0))
 | |
| 
 | |
|         # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
 | |
|         expand_batch_portion = list(
 | |
|             torch.broadcast_shapes(batch_tensor1, batch_tensor2)
 | |
|         )
 | |
| 
 | |
|         tensor1_expand_size = expand_batch_portion + [n, m1]
 | |
| 
 | |
|         expand_batch_product = prod(expand_batch_portion)
 | |
| 
 | |
|         # HACK: We need reshape with symint support
 | |
|         tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
 | |
|             expand_batch_product, n, m1
 | |
|         )
 | |
| 
 | |
|         vector_rhs = dim_tensor2 == 1
 | |
|         if vector_rhs:
 | |
|             tensor2_expand_size = expand_batch_portion + [m2]
 | |
|             tensor2_expanded = (
 | |
|                 tensor2.expand(tensor2_expand_size)
 | |
|                 .reshape(expand_batch_product, m2)
 | |
|                 .unsqueeze(2)
 | |
|             )
 | |
|         else:
 | |
|             tensor2_expand_size = expand_batch_portion + [m2, p]
 | |
|             tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
 | |
|                 expand_batch_product, m2, p
 | |
|             )
 | |
| 
 | |
|         output_shape = expand_batch_portion
 | |
|         if dim_tensor1 > 1:
 | |
|             output_shape.append(n)
 | |
| 
 | |
|         if dim_tensor2 > 1:
 | |
|             output_shape.append(p)
 | |
| 
 | |
|         if vector_rhs:
 | |
|             return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape)
 | |
|         else:
 | |
|             return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
 | |
|     else:
 | |
|         torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
 | |
| @aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def upsample_bicubic2d_default(
 | |
|     input: Tensor,
 | |
|     output_size: tuple[int, int],
 | |
|     align_corners: bool,
 | |
|     scale_h: Optional[float] = None,
 | |
|     scale_w: Optional[float] = None,
 | |
| ) -> Tensor:
 | |
|     # get dimensions of original image
 | |
|     _, _, in_h, in_w = input.shape
 | |
| 
 | |
|     # Calculate horizontal and vertical scaling factor
 | |
|     h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h)
 | |
|     w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w)
 | |
| 
 | |
|     _, dtype = utils.elementwise_dtypes(
 | |
|         input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
 | |
|     )
 | |
| 
 | |
|     # We have to create arange with int64 dtype and use .to in order to avoid
 | |
|     # additional kernels creation in inductor and get a perf slowdown
 | |
|     i = torch.arange(output_size[0], device=input.device).to(dtype=dtype)
 | |
|     j = torch.arange(output_size[1], device=input.device).to(dtype=dtype)
 | |
| 
 | |
|     x_float = _compute_source_index(w_scale_factor, j, align_corners)
 | |
|     y_float = _compute_source_index(h_scale_factor, i, align_corners)
 | |
|     y_float = y_float.unsqueeze(-1)
 | |
| 
 | |
|     x = x_float.floor()
 | |
|     y = y_float.floor()
 | |
| 
 | |
|     # We should also clamp xscale/yscale
 | |
|     # See guard_index_and_lambda in UpSample.h
 | |
|     yscale = (y_float - y).clamp(0.0, 1.0)
 | |
|     xscale = (x_float - x).clamp(0.0, 1.0)
 | |
|     x = x.to(torch.int64)
 | |
|     y = y.to(torch.int64)
 | |
| 
 | |
|     iys_ofs = (y - 1, y, y + 1, y + 2)
 | |
|     ixs_ofs = (x - 1, x, x + 1, x + 2)
 | |
| 
 | |
|     weights_x = _upsample_get_cubic_coefficients(xscale)
 | |
|     weights_y = _upsample_get_cubic_coefficients(yscale)
 | |
| 
 | |
|     weights_precision_x, weights_precision_y = None, None
 | |
|     if input.dtype == torch.uint8:
 | |
|         weights_precision_x = _compute_weight_precision(weights_x)
 | |
|         weights_precision_y = _compute_weight_precision(weights_y)
 | |
| 
 | |
|         weights_x = [
 | |
|             (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16)
 | |
|             for w in weights_x
 | |
|         ]
 | |
|         weights_y = [
 | |
|             (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16)
 | |
|             for w in weights_y
 | |
|         ]
 | |
| 
 | |
|     def load_bounded(ys, xs):
 | |
|         y_idx = torch.clamp(ys, 0, in_h - 1)
 | |
|         x_idx = torch.clamp(xs, 0, in_w - 1)
 | |
|         v = aten._unsafe_index(input, [None, None, y_idx, x_idx])
 | |
|         return v
 | |
| 
 | |
|     def get_x_interp(y):
 | |
|         src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs)
 | |
|         if input.dtype == torch.uint8:
 | |
|             assert weights_precision_x is not None
 | |
|             return _sum_tensors_uint8(src_x, weights_x, weights_precision_x)
 | |
|         return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x))
 | |
| 
 | |
|     src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs)
 | |
|     if input.dtype == torch.uint8:
 | |
|         assert weights_precision_y is not None
 | |
|         result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y)
 | |
|     else:
 | |
|         result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y))
 | |
| 
 | |
|     # convert output to correct memory format, if necessary
 | |
|     memory_format = utils.suggest_memory_format(input)
 | |
|     result = result.contiguous(memory_format=memory_format)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.upsample_bicubic2d.vec)
 | |
| @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
 | |
| @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper()
 | |
| @pw_cast_for_opmath
 | |
| def upsample_bicubic2d_vec(
 | |
|     a: Tensor,
 | |
|     output_size: Optional[tuple[int, int]],
 | |
|     align_corners: bool,
 | |
|     scale_factors: Optional[tuple[float, float]] = None,
 | |
| ) -> Tensor:
 | |
|     torch._check(
 | |
|         bool(output_size) + bool(scale_factors) == 1,
 | |
|         lambda: "Must specify exactly one of output_size and scale_factors.",
 | |
|     )
 | |
|     if output_size is None:
 | |
|         assert scale_factors is not None
 | |
|         output_size = cast(
 | |
|             tuple[int, int],
 | |
|             tuple(
 | |
|                 sym_int(sym_float(w) * scale)
 | |
|                 for w, scale in zip(a.shape[2:], scale_factors)
 | |
|             ),
 | |
|         )
 | |
|     scale_h, scale_w = scale_factors if scale_factors else (None, None)
 | |
|     return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.reflection_pad1d)
 | |
| @register_decomposition(aten.reflection_pad2d)
 | |
| @register_decomposition(aten.reflection_pad3d)
 | |
| @pw_cast_for_opmath
 | |
| @out_wrapper()
 | |
| def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
 | |
|     def idx(left, middle, right):
 | |
|         dim_idx = torch.arange(-left, middle + right, device=a.device)
 | |
|         return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
 | |
| 
 | |
|     return _reflection_or_replication_pad(
 | |
|         a,
 | |
|         padding,
 | |
|         idx,
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.replication_pad1d)
 | |
| @register_decomposition(aten.replication_pad2d)
 | |
| @register_decomposition(aten.replication_pad3d)
 | |
| @pw_cast_for_opmath
 | |
| @out_wrapper()
 | |
| def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
 | |
|     def idx(left, middle, right):
 | |
|         dim_idx = torch.arange(-left, middle + right, device=a.device)
 | |
|         return torch.clamp(dim_idx, 0, middle - 1)
 | |
| 
 | |
|     return _reflection_or_replication_pad(
 | |
|         a,
 | |
|         padding,
 | |
|         idx,
 | |
|     )
 | |
| 
 | |
| 
 | |
| def _reflection_or_replication_pad(
 | |
|     a: Tensor,
 | |
|     padding: tuple[int, ...],
 | |
|     idx_fn: Callable[[int, int, int], Tensor],
 | |
| ) -> Tensor:
 | |
|     dim = len(padding) // 2
 | |
|     torch._check(
 | |
|         a.dim() in (dim + 1, dim + 2),
 | |
|         lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
 | |
|     )
 | |
|     inp_shape = a.shape[-dim:]
 | |
|     nc_dim = a.dim() - dim
 | |
| 
 | |
|     padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
 | |
|     padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
 | |
| 
 | |
|     result = a
 | |
|     for i in range(dim):
 | |
|         idx: list[Any] = [None] * result.dim()
 | |
|         idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
 | |
|         result = aten._unsafe_index(result, idx)
 | |
| 
 | |
|     # convert output to correct memory format, if necessary
 | |
|     memory_format = utils.suggest_memory_format(result)
 | |
|     result = result.contiguous(memory_format=memory_format)
 | |
|     return result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.reflection_pad1d_backward)
 | |
| @register_decomposition(aten.reflection_pad2d_backward)
 | |
| @register_decomposition(aten.reflection_pad3d_backward)
 | |
| @out_wrapper("grad_input")
 | |
| def _reflection_pad_backward(grad_output, x, padding):
 | |
|     dim = len(padding) // 2
 | |
| 
 | |
|     dhw = [h - 1 for h in x.shape[-dim:]]
 | |
| 
 | |
|     padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
 | |
|     padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
 | |
| 
 | |
|     indices = []
 | |
|     for i in range(x.ndim):
 | |
|         view_shape = [1] * x.ndim
 | |
|         view_shape[i] = -1
 | |
|         indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape))
 | |
| 
 | |
|     b = indices[:-dim]
 | |
|     xyz = indices[-dim:]
 | |
| 
 | |
|     def index_range_condition(index_range):
 | |
|         i, lb, ub = index_range
 | |
|         return torch.logical_and(i >= lb, i <= ub)
 | |
| 
 | |
|     # Areas after reflection:
 | |
|     #
 | |
|     #   top-left    |   top     |   top-right
 | |
|     # -----------------------------------------
 | |
|     #   left        |   center  |   right
 | |
|     # -----------------------------------------
 | |
|     #   bottom-left |   bottom  |   bottom-right
 | |
|     #
 | |
|     # The center area is the original matrix. Other areas are reflections.
 | |
| 
 | |
|     center = [xyz[i] + padding_left[i] for i in range(dim)]
 | |
|     left_reflect = [padding_left[i] - xyz[i] for i in range(dim)]
 | |
|     right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)]
 | |
| 
 | |
|     # Accumulate gradients from different areas
 | |
|     # If some of the padding is negative, center load is not always valid
 | |
|     range_c = [
 | |
|         (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim)
 | |
|     ]
 | |
|     cond = functools.reduce(
 | |
|         aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)]
 | |
|     )
 | |
|     grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0)
 | |
| 
 | |
|     def accumulate(grad, out, index_ranges):
 | |
|         # If the upper bound is less than the lower bound, we can get rid of one accumulation.
 | |
|         # This happens when the padding size is zero.
 | |
|         for i in range(dim):
 | |
|             upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1]
 | |
|             if isinstance(upper_less_than_lower, bool) and upper_less_than_lower:
 | |
|                 return grad
 | |
| 
 | |
|         cond = functools.reduce(
 | |
|             aten.logical_and,
 | |
|             [index_range_condition(index_range) for index_range in index_ranges],
 | |
|         )
 | |
|         g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0)
 | |
|         return grad + g
 | |
| 
 | |
|     for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]):
 | |
|         if area == tuple([0] * dim):
 | |
|             # center, this is already done.
 | |
|             continue
 | |
| 
 | |
|         outs = []
 | |
|         index_ranges = []
 | |
| 
 | |
|         for i in range(dim):
 | |
|             if area[i] == 0:
 | |
|                 out = center[i]
 | |
|                 index_range = range_c[i]
 | |
|             elif area[i] == -1:
 | |
|                 out = left_reflect[i]
 | |
|                 index_range = (xyz[i], 1, padding_left[i])
 | |
|             elif area[i] == 1:
 | |
|                 out = right_reflect[i]
 | |
|                 index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)
 | |
| 
 | |
|             outs.append(out)  # type: ignore[possibly-undefined]
 | |
|             index_ranges.append(index_range)  # type: ignore[possibly-undefined]
 | |
| 
 | |
|         grad = accumulate(grad, outs, index_ranges)
 | |
| 
 | |
|     return grad
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.aminmax)
 | |
| @out_wrapper("min", "max")
 | |
| def aminmax(self, *, dim=None, keepdim=False):
 | |
|     # pyrefly: ignore  # bad-argument-type
 | |
|     amin = torch.amin(self, dim=dim, keepdim=keepdim)
 | |
|     # pyrefly: ignore  # bad-argument-type
 | |
|     amax = torch.amax(self, dim=dim, keepdim=keepdim)
 | |
|     return amin, amax
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.nansum)
 | |
| @out_wrapper()
 | |
| def nansum(self, dim=None, keepdim=False, *, dtype=None):
 | |
|     return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.arange.default, aten.arange.out])
 | |
| @out_wrapper()
 | |
| def arange_default(
 | |
|     end: NumberType,
 | |
|     *,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     layout: torch.layout = torch.strided,
 | |
|     device: Optional[torch.device] = None,
 | |
|     pin_memory: bool = False,
 | |
| ):
 | |
|     return aten.arange.start_step(
 | |
|         0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.arange.start])
 | |
| def arange_start(
 | |
|     start: NumberType,
 | |
|     end: NumberType,
 | |
|     *,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     layout: torch.layout = torch.strided,
 | |
|     device: Optional[torch.device] = None,
 | |
|     pin_memory: bool = False,
 | |
| ):
 | |
|     return aten.arange.start_step(
 | |
|         start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
 | |
|     )
 | |
| 
 | |
| 
 | |
| @register_decomposition(out_dtype)
 | |
| def out_dtype_decomp(*args, **kwargs):
 | |
|     from torch._higher_order_ops.out_dtype import out_dtype_dense
 | |
| 
 | |
|     return out_dtype_dense(*args, **kwargs)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.multi_margin_loss)
 | |
| @aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper()
 | |
| def multi_margin_loss(
 | |
|     input: Tensor,
 | |
|     target: Tensor,
 | |
|     p: NumberType = 1,
 | |
|     margin: NumberType = 1,
 | |
|     weight: Optional[Tensor] = None,
 | |
|     reduction: int = Reduction.MEAN.value,
 | |
| ) -> Tensor:
 | |
|     input = torch.atleast_2d(input)
 | |
|     target = torch.atleast_1d(target)
 | |
|     nframe = input.shape[0]
 | |
|     dim = input.shape[1]
 | |
|     torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
 | |
|     torch._check(
 | |
|         input.ndim == 2 and dim != 0,
 | |
|         lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
 | |
|     )
 | |
|     torch._check(
 | |
|         target.ndim == 1 and target.numel() == nframe,
 | |
|         lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
 | |
|     )
 | |
|     if weight is not None:
 | |
|         weight = torch.atleast_1d(weight)
 | |
|         torch._check(
 | |
|             weight.ndim == 1 and weight.numel() == dim,  # type: ignore[union-attr]
 | |
|             lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",  # type: ignore[union-attr]
 | |
|         )
 | |
|     target = target.unsqueeze(1)
 | |
|     u = torch.gather(input, dim=1, index=target)
 | |
|     z = margin - u + input
 | |
|     z = z.clamp_min(0)
 | |
|     z = z if p == 1 else z * z
 | |
|     if weight is not None:
 | |
|         z = z * weight[target]
 | |
|     idx = torch.arange(dim, device=input.device)
 | |
|     z = torch.where(idx != target, z, 0)
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         return z.mean()
 | |
|     elif reduction == Reduction.SUM.value:
 | |
|         return z.sum() / z.shape[1]
 | |
|     else:
 | |
|         return z.mean(dim=1)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.multilabel_margin_loss_forward)
 | |
| @aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd)
 | |
| @out_wrapper("output", "is_target")
 | |
| def multilabel_margin_loss_forward(
 | |
|     input: Tensor,
 | |
|     target: Tensor,
 | |
|     reduction: int,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     orig_input_shape = input.shape
 | |
|     orig_target_shape = target.shape
 | |
|     input = torch.atleast_2d(input)
 | |
|     target = torch.atleast_2d(target)
 | |
|     dim = input.shape[1]
 | |
|     torch._check(
 | |
|         len(orig_input_shape) <= 2 and dim != 0,
 | |
|         lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}",
 | |
|     )
 | |
|     torch._check(
 | |
|         len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape,
 | |
|         lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}",
 | |
|     )
 | |
|     # ignores labels after the first -1, detects when -1 is not present
 | |
|     idx = torch.arange(dim, device=target.device)
 | |
|     is_end = target == -1
 | |
|     end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True)
 | |
|     # target indices
 | |
|     target_mask = idx < end_idx
 | |
|     # masks target to be able to use gather, which doesn't allow -1
 | |
|     tidx0 = torch.where(target_mask, target, 0)
 | |
|     u = torch.gather(input, dim=-1, index=tidx0)
 | |
|     # is_target
 | |
|     tidx1 = torch.where(target_mask, target, -1)
 | |
|     is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1)
 | |
|     # loss
 | |
|     z = 1.0 - u.T.unsqueeze(dim=-1) + input
 | |
|     z = z.clamp_min(0)
 | |
|     z = z / dim
 | |
|     # masks loss
 | |
|     z = torch.where(is_target, 0, z)
 | |
|     # reduction
 | |
|     if reduction == Reduction.MEAN.value:
 | |
|         z = z.sum(dim=(0, -1)).mean()
 | |
|     elif reduction == Reduction.SUM.value:
 | |
|         z = z.sum()
 | |
|     else:
 | |
|         z = z.sum(dim=(0, -1))
 | |
|     # result
 | |
|     is_target = is_target.to(input.dtype).reshape(orig_target_shape)
 | |
|     return z, is_target
 | |
| 
 | |
| 
 | |
| # scaled_dot_product_attention used to be decomposed in pre-autograd, given that
 | |
| # it calls _scaled_dot_product_attention_math and
 | |
| # _scaled_dot_product_attention_math only has a CompositeImplicitAutograd
 | |
| # kernel. As a result it's decomposed into ops with finer granularity.
 | |
| # However recent PRs (#103826 #105131 #115913) added new logic in
 | |
| # scaled_dot_product_attention and now it calls
 | |
| # _scaled_dot_product_flash_attention_for_cpu in export path. This results
 | |
| # in _scaled_dot_product_flash_attention_for_cpu showing up in export result.
 | |
| # This decomposition ensures scaled_dot_product_attention is still decomposed
 | |
| # the same way as before, i.e., going through
 | |
| # _scaled_dot_product_attention_math. Notice that this decomp rule should be
 | |
| # excluded by inductor.
 | |
| @register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
 | |
| def scaled_dot_product_flash_attention_for_cpu(
 | |
|     query: Tensor,
 | |
|     key: Tensor,
 | |
|     value: Tensor,
 | |
|     dropout_p: float = 0.0,
 | |
|     is_causal: bool = False,
 | |
|     *,
 | |
|     attn_mask: Optional[Tensor] = None,
 | |
|     scale: Optional[float] = None,
 | |
| ) -> tuple[Tensor, Tensor]:
 | |
|     torch._check(
 | |
|         torch.is_floating_point(query),
 | |
|         lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
 | |
|     )
 | |
|     torch._check(
 | |
|         query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
 | |
|         lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
 | |
|     )
 | |
|     torch._check(
 | |
|         dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}"
 | |
|     )
 | |
|     torch._check(
 | |
|         query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3],
 | |
|         lambda: "q, k, v should have the same head size",
 | |
|     )
 | |
| 
 | |
|     output, attn = aten._scaled_dot_product_attention_math.default(
 | |
|         query,
 | |
|         key,
 | |
|         value,
 | |
|         attn_mask=attn_mask,
 | |
|         dropout_p=dropout_p,
 | |
|         is_causal=is_causal,
 | |
|         dropout_mask=None,
 | |
|         scale=scale,
 | |
|         enable_gqa=query.size(1) != key.size(1),
 | |
|     )
 | |
|     # Why this change?
 | |
|     # In pre-dispatch export scaled_dot_product_attention is executed via
 | |
|     # * flash_attention.
 | |
|     # flash_attention allocates output tensor as (N, H, L, E) (see PR #134656)
 | |
|     # assume x: [N, H, L, E] is the output sdpa
 | |
|     # In MHA code, this output is then permuted via (2, 0, 1, 3) to get
 | |
|     # (L, N, H, E) dim tensor
 | |
|     # x = x.permute(2, 0, 1, 3).contiguous() and the viewed via
 | |
|     # x = x.view(L * N, H * E)
 | |
|     # During pre autograd dispatch call to contiguous is not traced because
 | |
|     # flash_attention output after the x.permute is already contiguous
 | |
|     # on which the view is valid
 | |
|     # However, during 2nd stage export, post-dispatch, we run _match variant
 | |
|     # instead of flash* to get the decomposition. _match variant returns
 | |
|     # x: [N, H, L, E] applying x.permute(2, 0, 1, 3) returns
 | |
|     # x: [L, N, H, E] and without converting this to contiguous tensor
 | |
|     # subsequent view is not valid and the export fails
 | |
|     # solution is to maintain the return tensor view from the decomp to be
 | |
|     # exactly same as *flash* variant.
 | |
| 
 | |
|     # Really the invariant you want to maintain is:
 | |
|     # pre-dispatch op-output and its decomposed representation must
 | |
|     # return tensor with same view and dims
 | |
|     output = (
 | |
|         output.permute(2, 0, 1, 3)
 | |
|         .contiguous(memory_format=torch.contiguous_format)
 | |
|         .permute(1, 2, 0, 3)
 | |
|     )
 | |
|     return output, attn
 | |
| 
 | |
| 
 | |
| def register_inplace(aten_op, outplace_op):
 | |
|     @register_decomposition(aten_op)
 | |
|     def inplace_op(*args, **kwargs):
 | |
|         out = outplace_op(*args, **kwargs)
 | |
|         return args[0].copy_(out)
 | |
| 
 | |
|     return inplace_op
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.baddbmm])
 | |
| @out_wrapper(exact_dtype=True)
 | |
| @pw_cast_for_opmath
 | |
| def baddbmm(self, batch1, batch2, beta=1, alpha=1):
 | |
|     if not self.is_floating_point() and not self.is_complex():
 | |
|         beta = int(beta)
 | |
|         alpha = int(alpha)
 | |
|     result = torch.bmm(batch1, batch2)
 | |
|     if not isinstance(alpha, numbers.Number) or alpha != 1:
 | |
|         # pyrefly: ignore  # unsupported-operation
 | |
|         result = result * alpha
 | |
|     if beta == 0:
 | |
|         return result
 | |
|     if not isinstance(beta, numbers.Number) or beta != 1:
 | |
|         self = self * beta
 | |
|     return self + result
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.floor_divide)
 | |
| @out_wrapper()
 | |
| def floor_divide(self, other):
 | |
|     return torch.div(self, other, rounding_mode="floor")
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.sym_numel)
 | |
| def sym_numel(t):
 | |
|     return functools.reduce(operator.mul, t.shape, 1)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.sum.default, aten.sum.out])
 | |
| def sum_default(
 | |
|     self: Tensor,
 | |
|     *,
 | |
|     dtype: Optional[torch.dtype] = None,
 | |
|     out: Optional[Tensor] = None,
 | |
| ) -> Tensor:
 | |
|     if out is None:
 | |
|         return aten.sum.dim_IntList(self, [], dtype=dtype)
 | |
|     else:
 | |
|         return aten.sum.IntList_out(self, [], dtype=dtype, out=out)
 | |
| 
 | |
| 
 | |
| @register_decomposition([aten.squeeze.default, aten.squeeze.dim])
 | |
| def squeeze_default(self: Tensor, dim: Optional[int] = None):
 | |
|     # handle a scalar directly
 | |
|     if not isinstance(self, torch.Tensor):
 | |
|         return self
 | |
|     # perform squeeze
 | |
|     if dim is None:
 | |
|         return aten.squeeze.dims(self, list(range(self.dim())))
 | |
|     else:
 | |
|         return aten.squeeze.dims(self, [dim])
 | |
| 
 | |
| 
 | |
| @register_decomposition(torch.ops.aten._weight_norm_interface)
 | |
| def _weight_norm_interface(v, g, dim=0):
 | |
|     # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
 | |
|     keep_dim = tuple(i for i in range(len(v.shape)) if i != dim)
 | |
|     # align with cuda behavior, keep norm in 'float' when g is 'bfloat16'
 | |
|     norm_dtype = torch.float if g.dtype == torch.bfloat16 else None
 | |
|     norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype)
 | |
|     return v * (g / norm.to(g.dtype)), norm
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.isin)
 | |
| @out_wrapper()
 | |
| def isin(elements, test_elements, *, assume_unique=False, invert=False):
 | |
|     # handle when either elements or test_elements are Scalars (they can't both be)
 | |
|     if not isinstance(elements, torch.Tensor):
 | |
|         elements = torch.tensor(elements, device=test_elements.device)
 | |
|     if not isinstance(test_elements, torch.Tensor):
 | |
|         if invert:
 | |
|             return torch.ne(elements, test_elements)
 | |
|         else:
 | |
|             return torch.eq(elements, test_elements)
 | |
| 
 | |
|     if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
 | |
|         return isin_default(elements, test_elements, invert=invert)
 | |
|     else:
 | |
|         return isin_sorting(
 | |
|             elements, test_elements, assume_unique=assume_unique, invert=invert
 | |
|         )
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.bernoulli.default)
 | |
| def bernoulli(
 | |
|     self: torch.Tensor,
 | |
|     *,
 | |
|     generator: Optional[torch.Generator] = None,
 | |
| ) -> torch.Tensor:
 | |
|     if generator is None:
 | |
|         raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device)
 | |
|     else:
 | |
|         raw_p = torch.rand(
 | |
|             self.size(),
 | |
|             generator=generator,
 | |
|             dtype=torch.float32,
 | |
|             device=self.device,
 | |
|         )
 | |
|     p = (raw_p < self).to(self.dtype)
 | |
|     return p
 | |
| 
 | |
| 
 | |
| def isin_default(elements, test_elements, *, invert=False):
 | |
|     if elements.numel() == 0:
 | |
|         return torch.empty_like(elements, dtype=torch.bool)
 | |
|     expanded_elem_shape = elements.shape + (1,) * test_elements.ndim
 | |
|     x = elements.view(expanded_elem_shape)
 | |
|     dim = tuple(range(-1, -test_elements.ndim - 1, -1))
 | |
|     res = (x == test_elements).any(dim=dim)
 | |
|     return ~res if invert else res
 | |
| 
 | |
| 
 | |
| def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):
 | |
|     elements_flat = elements.flatten()
 | |
|     test_elements_flat = test_elements.flatten()
 | |
|     if assume_unique:
 | |
|         # This is the same as the aten implementation. For
 | |
|         # assume_unique=False, we cannot use unique() here, so we use a
 | |
|         # version with searchsorted instead.
 | |
|         all_elements = torch.cat([elements_flat, test_elements_flat])
 | |
|         sorted_elements, sorted_order = torch.sort(all_elements, stable=True)
 | |
| 
 | |
|         duplicate_mask = sorted_elements[1:] == sorted_elements[:-1]
 | |
|         duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False)
 | |
| 
 | |
|         if invert:
 | |
|             duplicate_mask = duplicate_mask.logical_not()
 | |
| 
 | |
|         mask = torch.empty_like(duplicate_mask)
 | |
|         mask = mask.index_copy(0, sorted_order, duplicate_mask)
 | |
| 
 | |
|         return mask[0 : elements.numel()]
 | |
|     else:
 | |
|         sorted_test_elements, _ = torch.sort(test_elements_flat)
 | |
|         idx = torch.searchsorted(sorted_test_elements, elements_flat)
 | |
|         test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0)
 | |
|         cmp = sorted_test_elements[test_idx] == elements_flat
 | |
|         cmp = cmp.logical_not() if invert else cmp
 | |
|         return cmp.reshape(elements.shape)
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.take)
 | |
| @out_wrapper()
 | |
| def take(self, index):
 | |
|     flattened = self.reshape(-1)
 | |
|     return flattened[index]
 | |
| 
 | |
| 
 | |
| @register_decomposition(aten.resize_as)
 | |
| def resize_as(self, other, memory_format=None):
 | |
|     if memory_format is None:
 | |
|         memory_format = torch.contiguous_format
 | |
|     if memory_format == torch.preserve_format:
 | |
|         memory_format = suggest_memory_format(other)
 | |
|     return aten.resize(self, other.shape, memory_format=memory_format)
 | |
| 
 | |
| 
 | |
| register_inplace(aten.addbmm_, aten.addbmm)
 | |
| register_inplace(aten.addmm_, aten.addmm)
 | |
| register_inplace(aten.addmv_, aten.addmv)
 | |
| register_inplace(aten.baddbmm_, aten.baddbmm)
 | |
| register_inplace(aten.fill_, aten.fill)
 | |
| register_inplace(aten.gelu_, aten.gelu)
 | |
| register_inplace(aten.hardswish_, aten.hardswish)
 | |
| register_inplace(aten.hardtanh_, aten.hardtanh)
 | |
| register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
 | |
| register_inplace(aten.__iand__, aten.__and__)
 | |
| register_inplace(aten.__ilshift__, aten.__lshift__)
 | |
| register_inplace(aten.index_put_, aten.index_put)
 | |
| register_inplace(aten.index_reduce_, aten.index_reduce)
 | |
| register_inplace(aten.__ior__, aten.__or__)
 | |
| register_inplace(aten.__irshift__, aten.__rshift__)
 | |
| register_inplace(aten.__ixor__, aten.__xor__)
 | |
| register_inplace(aten.leaky_relu_, aten.leaky_relu)
 | |
| register_inplace(aten.logit_, aten.logit)
 | |
| register_inplace(aten.relu_, aten.relu)
 | |
| register_inplace(aten.renorm_, aten.renorm)
 | |
| register_inplace(aten.round_, aten.round)
 | |
| register_inplace(aten.scatter_, aten.scatter)
 | |
| register_inplace(aten.scatter_add_, aten.scatter_add)
 | |
| register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
 | |
| register_inplace(aten.silu_, aten.silu)
 |