import torch
from torch import Tensor
from torch._decomp import register_decomposition
from enum import Enum
from typing import Tuple, Optional, List, Callable
import torch.nn.functional as F
import functools
from torch.utils._pytree import tree_map, tree_flatten
import torch._prims.utils as utils

# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []

aten = torch.ops.aten


class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2


# This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
# We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
# Will need to validate the non-elementwise uses
def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        flat_args = [x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)]
        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *flat_args, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def increase_prec(x):
            if isinstance(x, Tensor):
                return x.to(computation_dtype)
            else:
                return x

        def decrease_prec(x):
            if isinstance(x, Tensor):
                return x.to(result_dtype)
            else:
                return x

        r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
        return tree_map(decrease_prec, r)

    return inner

pw_cast_for_opmath = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
reduction_complex_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
pw_cast_for_int_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
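
# Illustrative sketch (not part of the original file): a decomposition wrapped
# with pw_cast_for_opmath has its Tensor arguments upcast to the computation
# dtype (e.g. float16 -> float32) and the result cast back, assuming the
# default elementwise type promotion rules. The names below are hypothetical:
#
#   @pw_cast_for_opmath
#   def scaled_diff(a: Tensor, b: Tensor) -> Tensor:
#       return (a - b) * 2
#
#   x = torch.randn(4, dtype=torch.float16)
#   y = torch.randn(4, dtype=torch.float16)
#   out = scaled_diff(x, y)  # computed in float32, returned as float16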

# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
    for _ in range(dim - x.dim()):
        x = x.unsqueeze(-1)
    return x


@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y).conj_physical()
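
# Example (illustrative, not part of the original file): for real inputs this
# decomposition should agree with autograd's tanh gradient, e.g.
#
#   x = torch.randn(8, requires_grad=True)
#   y = torch.tanh(x)
#   (grad,) = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y))
#   torch.testing.assert_close(grad, tanh_backward(torch.ones_like(y), y.detach()))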


@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y)).conj_physical()


@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
    z = (x * beta).exp()
    return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))


@register_decomposition(aten.elu)
@pw_cast_for_opmath
def elu(
    self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
) -> Tensor:
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    return torch.where(
        self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef
    )


@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
    grad_output: Tensor,
    alpha: float,
    scale: float,
    input_scale: float,
    is_result: bool,
    self_or_result: Tensor,
):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
        )
    else:
        return torch.where(
            self_or_result <= 0,
            grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
            grad_output * poscoef,
        )


@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
    return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
    return torch.where(
        (self > -3.0) & (self < 3.0),
        grad_output * (1.0 / 6.0),
        grad_output.new_zeros(()),
    )


@register_decomposition(aten.hardtanh)
@pw_cast_for_opmath
def hardtanh(self: Tensor, min_val: float = -1, max_val: float = 1) -> Tensor:
    return torch.clamp(self, min_val, max_val)


@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
    grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
    return torch.where(
        (self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output
    )


@register_decomposition(aten.hardshrink_backward)
@pw_cast_for_opmath
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    return torch.where(
        (self >= -lambd) & (self <= lambd), grad_out.new_zeros(()), grad_out
    )


@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
    return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6


@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    return torch.where(
        self < -3,
        grad_output.new_zeros(()),
        torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
    )


@register_decomposition(aten.threshold_backward)
@pw_cast_for_opmath
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
    return torch.where(self <= threshold, grad_output.new_zeros(()), grad_output)


@register_decomposition(aten.leaky_relu)
@pw_cast_for_opmath
def leaky_relu(self: Tensor, negative_slope: float = 0.01) -> Tensor:
    return torch.where(self > 0, self, self * negative_slope)


@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
    grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
    return torch.where(self > 0, grad_output, grad_output * negative_slope)


@register_decomposition(aten.gelu)
@pw_cast_for_opmath
def gelu(self: Tensor, approximate: str = 'none') -> Tensor:
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == 'tanh':
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_cube = self * self * self
        inner = kBeta * (self + kKappa * x_cube)
        return 0.5 * self * (1 + torch.tanh(inner))
    else:
        kAlpha = M_SQRT1_2
        return self * 0.5 * (1 + torch.erf(self * kAlpha))
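
# Example (illustrative, not part of the original file): both branches should
# match the reference implementation, assuming a torch build where F.gelu
# accepts `approximate`:
#
#   x = torch.randn(16)
#   torch.testing.assert_close(gelu(x), F.gelu(x))
#   torch.testing.assert_close(gelu(x, approximate='tanh'), F.gelu(x, approximate='tanh'))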


@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
    M_SQRT2 = 1.41421356237309504880
    M_SQRT1_2 = 0.70710678118654752440
    M_2_SQRTPI = 1.12837916709551257390
    if approximate == 'tanh':
        kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
        kKappa = 0.044715
        x_sq = self * self
        x_cube = x_sq * self
        inner = kBeta * (self + kKappa * x_cube)
        tanh_inner = torch.tanh(inner)

        left = 0.5 * self
        right = 1 + tanh_inner

        left_derivative = 0.5 * right

        tanh_derivative = 1 - tanh_inner * tanh_inner
        inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
        right_derivative = left * tanh_derivative * inner_derivative

        return grad * (left_derivative + right_derivative)
    else:
        kAlpha = M_SQRT1_2
        kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
        cdf = 0.5 * (1 + torch.erf(self * kAlpha))
        pdf = kBeta * torch.exp(self * self * -0.5)
        return grad * (cdf + self * pdf)


@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
    input_tanh_softplus = torch.tanh(F.softplus(input))
    input_sigmoid = torch.sigmoid(input)
    out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
    return grad_output * (input_tanh_softplus + out)


@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
    return self * torch.sigmoid(self)


@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + torch.exp(-self))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))


@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    return torch.where(
        (self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output
    )


@register_decomposition(aten.prelu_backward)
@pw_cast_for_opmath
def prelu_backward(
    grad_output: Tensor, self: Tensor, weight: Tensor
) -> Tuple[Tensor, Tensor]:
    # Logic is more complicated than I would like. Basically, weight can either
    # be a scalar or a vector of size [C], and in the forward pass it's
    # broadcast against [N, C, ...]. So now, we need to do the corresponding
    # reduction, which is harder than we'd like...
    cur_weight = weight
    for _ in range(2, grad_output.dim()):
        cur_weight = cur_weight.unsqueeze(-1)
    input_grad = torch.where(self > 0, grad_output, cur_weight * grad_output)
    weight_grad_collector = torch.where(
        self > 0, grad_output.new_zeros(()), self * grad_output
    )
    out = weight_grad_collector.sum_to_size(cur_weight.shape)
    while out.dim() > weight.dim():
        out = out.squeeze(-1)
    return (input_grad, out)
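
# Example (illustrative, not part of the original file): for an input of shape
# [N, C, H, W] and a per-channel weight of shape [C], the weight gradient is
# reduced back down to shape [C]:
#
#   x = torch.randn(2, 3, 4, 4)
#   w = torch.rand(3)
#   input_grad, weight_grad = prelu_backward(torch.ones_like(x), x, w)
#   assert weight_grad.shape == w.shape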


@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
    grad_output: Tensor,
    self: Tensor,
    noise: Tensor,
    lower: float,
    upper: float,
    training: bool,
    self_is_result: bool,
) -> Tensor:
    if training and upper - lower > 1e-6:
        return grad_output.mul(noise)
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu_backward(grad_output, self, negative_slope, self_is_result)


@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = torch.where(in_negative, 1, 0)
    sign = torch.where(in_negative, 1, -1)
    z = torch.exp(-torch.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


def apply_loss_reduction(loss: Tensor, reduction: int):
    if reduction == Reduction.MEAN.value:
        return torch.mean(loss)
    elif reduction == Reduction.SUM.value:
        return torch.sum(loss)
    else:
        return loss
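
# Example (illustrative, not part of the original file):
#
#   loss = torch.tensor([1.0, 2.0, 3.0])
#   apply_loss_reduction(loss, Reduction.MEAN.value)  # tensor(2.)
#   apply_loss_reduction(loss, Reduction.SUM.value)   # tensor(6.)
#   apply_loss_reduction(loss, Reduction.NONE.value)  # returned unchanged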


def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64

# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction

@register_decomposition(aten.l1_loss)
def l1_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target).abs()
    # PyTorch semantics result in the output of l1_loss having the corresponding
    # real dtype to self. This may not happen without explicit casting if say
    # self: complex64 and target: float64, which results in loss: float64
    float_type = to_real_dtype(self.dtype)
    return apply_loss_reduction(loss, reduction).to(float_type)
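
# Example (illustrative, not part of the original file): for complex inputs the
# loss comes out in the corresponding real dtype, matching eager semantics:
#
#   a = torch.randn(4, dtype=torch.complex64)
#   b = torch.randn(4, dtype=torch.complex64)
#   assert l1_loss(a, b).dtype == torch.float32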


@register_decomposition(aten.l1_loss_backward)
@pw_cast_for_opmath
def l1_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
):
    sign = torch.sign(self - target)

    norm = sign / self.numel() if reduction == Reduction.MEAN.value else sign
    return grad_output * norm


@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
    grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
    norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
    return norm * (input - target) * grad_output


@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
    self: Tensor,
    target: Tensor,
    reduction: int = Reduction.MEAN.value,
    delta: float = 1.0,
) -> Tensor:
    assert delta > 0, "huber_loss does not support non-positive values for delta."
    z = (self - target).abs()
    loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.huber_loss_backward)
@pw_cast_for_opmath
def huber_loss_backward(
    grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
    norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
    x = self - target
    return torch.where(
        x < -delta,
        -norm * grad_output * delta,
        torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
    )


def _nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    channel_dim = 0 if self.dim() < 2 else 1
    if reduction == Reduction.MEAN.value:
        grad_output = grad_output / total_weight

    target = target.unsqueeze(channel_dim)
    grad_input = torch.zeros_like(self)
    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)

    if grad_input.dim() > grad_output.dim() > 0:
        grad_output = grad_output.unsqueeze(channel_dim)

    if weight is not None:
        new_shape = [1 for _ in range(self.dim())]
        new_shape[channel_dim] = weight.shape[0]
        weight = weight.reshape(new_shape)
        grad_output = grad_output * weight

    has_ignore_index = ignore_index >= 0
    if has_ignore_index:
        ignore_index_mask = target != ignore_index
        grad_output = grad_output * ignore_index_mask

    return grad_input * grad_output

@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
    assert (
        target.dim() <= 1
    ), "0D or 1D target tensor expected, multi-target not supported"

    no_batch_dim = self.dim() == 1 and target.dim() == 0
    assert no_batch_dim or (
        self.shape[0] == target.shape[0]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: ",
        f"{total_weight.shape} ({total_weight.numel()} elements)",
    )

    assert (
        weight is None or weight.numel() == self.shape[-1]
    ), "weight tensor should be defined either for all or no classes"

    if reduction == Reduction.NONE.value and self.dim() == 2:
        assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
            f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
            f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
        )
    else:
        assert (
            grad_output.dim() <= 1 and grad_output.numel() == 1
        ), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"

    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)


@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
    total_weight: Tensor,
) -> Tensor:
    assert (
        self.dim() == 4
    ), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"

    assert (
        target.dim() == 3
    ), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"

    assert (
        self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape}"

    assert (
        total_weight.numel() == 1
    ), (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
    )

    return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)


@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    # We cannot currently model this without introducing data-dependent control flow
    # TORCH_CHECK(
    #     (input_val >= 0) && (input_val <= 1),
    #     "all elements of input should be between 0 and 1"
    # )
    loss = (target - 1) * torch.maximum(
        torch.log(1 - self), self.new_full((), -100)
    ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
    if weight is not None:
        loss = loss * weight
    return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
    grad_output: Tensor,
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor] = None,
    reduction: int = Reduction.MEAN.value,
) -> Tensor:
    EPSILON = 1e-12
    result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
    if weight is not None:
        result = result * weight
    if reduction == Reduction.MEAN.value:
        result = result / self.numel()
    return result


@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
    x1_norm = x1.pow(2).sum(-1, True)
    x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
    x2_norm = x2.pow(2).sum(-1, True)
    x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
    x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
    x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
    result = x1_.matmul(x2_.mT)
    return result.clamp_min(0).sqrt()
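
# Example (illustrative, not part of the original file): this should agree with
# torch.cdist for 2-D inputs, up to the extra rounding introduced by the
# (-2*x1) @ x2 + ||x1||^2 + ||x2||^2 expansion:
#
#   x1 = torch.randn(5, 3)
#   x2 = torch.randn(7, 3)
#   torch.testing.assert_close(_euclidean_dist(x1, x2), torch.cdist(x1, x2),
#                              rtol=1e-4, atol=1e-5)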


@register_decomposition(aten.slice_backward)
def slice_backward(
    grad_output: Tensor,
    input_sizes: List[int],
    dim: int,
    start: int,
    end: int,
    step: int,
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)


@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.select_scatter(grad_input, grad_output, dim, index)


@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
    grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
    grad_input = grad_output.new_zeros(input_sizes)
    return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
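
# Example (illustrative, not part of the original file): each of these
# backwards scatters grad_output into a zero tensor of the original input size,
# e.g. for select_backward:
#
#   grad_out = torch.ones(3)               # gradient of x[1] for x of shape (4, 3)
#   g = select_backward(grad_out, [4, 3], 0, 1)
#   # g is zero everywhere except row 1, which is all ones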


@register_decomposition(aten._softmax_backward_data)
@pw_cast_for_opmath
def _softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    new_grad = grad_output * output
    return new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)


@register_decomposition(aten._log_softmax_backward_data)
@pw_cast_for_opmath
def _log_softmax_backward_data(
    grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
    grad_input = grad_output - torch.exp(output) * torch.sum(
        grad_output, dim=dim, keepdim=True
    )
    return grad_input


# TODO: the type annotations on arguments are not quite right


@register_decomposition(aten.im2col_backward)
def im2col_backward(
    grad_output: Tensor,
    input_size: List[int],
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    return F.fold(grad_output, input_size, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]


@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor,
    kernel_size: List[int],
    dilation: List[int],
    padding: List[int],
    stride: List[int],
) -> Tensor:
    return F.unfold(grad_output, kernel_size, dilation, padding, stride)  # type: ignore[arg-type]


@register_decomposition(aten.masked_fill.Scalar)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
    return torch.where(mask, utils.dtype_to_type(self.dtype)(value), self)


@register_decomposition(aten.masked_fill.Tensor)
def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
    return torch.where(mask, value, self)


@register_decomposition(aten.native_dropout_backward)
@pw_cast_for_opmath
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    return grad_output * (mask.type_as(grad_output) * scale)


@register_decomposition(aten.logit)
@pw_cast_for_int_to_real
def logit(self: Tensor, eps: Optional[float] = None) -> Tensor:
    if eps is None:
        eps = -1.0
    lo = eps
    hi = 1 - eps
    self = torch.clamp(self, lo, hi)
    return (self / (1 - self)).log()


@register_decomposition(aten.logit_backward)
@pw_cast_for_opmath
def logit_backward(
    grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
    if eps is not None:
        lo = eps
        hi = 1.0 - lo
        return torch.where(
            torch.logical_and(self >= lo, self <= hi),
            grad_output / (self * (1.0 - self)),
            self.new_zeros(()),
        )
    else:
        return torch.where(
            torch.logical_and(self >= 0.0, self <= 1.0),
            grad_output / (self * (1.0 - self)),
            self.new_full((), float("nan")),
        )


@register_decomposition(aten.native_dropout)
@pw_cast_for_opmath
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
    if train:
        bool_mask = torch.rand_like(input) < p
        res = bool_mask * input * float(1.0 / p)
        return (res, bool_mask)
    else:
        return (input, torch.ones_like(input, dtype=torch.bool))


# TODO: Correct the type promotion semantics
@register_decomposition(aten._softmax)
@pw_cast_for_opmath
def _softmax(x: Tensor, dim: int, half_to_float: bool):
    x_max = torch.max(x, dim, keepdim=True)[0]
    unnormalized = torch.exp(x - x_max)
    return unnormalized / torch.sum(unnormalized, dim, keepdim=True)


# TODO: Correct the type promotion semantics
@register_decomposition(aten._log_softmax)
@pw_cast_for_opmath
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
    x_max = torch.max(x, dim, keepdim=True)[0]
    shifted = x - x_max
    shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
    return shifted - shifted_logsumexp
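
# Example (illustrative, not part of the original file): subtracting the
# per-dim max keeps exp() from overflowing, so these decompositions stay finite
# even for very large logits:
#
#   x = torch.tensor([[1000.0, 1001.0, 1002.0]])
#   assert torch.isfinite(_softmax(x, -1, False)).all()
#   assert torch.isfinite(_log_softmax(x, -1, False)).all()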


@register_decomposition(aten.addcdiv)
@pw_cast_for_opmath
def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    return self + value * (tensor1 / tensor2)


# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
@pw_cast_for_opmath
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    if self.is_floating_point() or self.is_complex():
        return self + value * tensor1 * tensor2
    else:
        return self + int(value) * tensor1 * tensor2


@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
    return torch.sub(other, self, alpha=alpha)


@register_decomposition(aten.embedding)
def embedding(
    weight: Tensor,
    indices: Tensor,
    padding_idx: int = -1,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tensor:
    assert weight.dim() == 2, "'weight' must be 2-D"
    # TODO: Assert not ported over yet
    # auto indices_arg = TensorArg(indices, "indices", 1);
    # checkScalarTypes("embedding", indices_arg, {kLong, kInt});

    if indices.dim() == 1:
        return weight.index_select(0, indices)

    size = list(indices.shape)
    for d in weight.shape[1:]:
        size.append(d)

    return weight.index_select(0, indices.reshape(-1)).view(size)
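
# Example (illustrative, not part of the original file): matches F.embedding
# (which takes indices first) for a 2-D index tensor:
#
#   weight = torch.randn(10, 4)
#   idx = torch.randint(0, 10, (2, 3))
#   torch.testing.assert_close(embedding(weight, idx), F.embedding(idx, weight))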

# TODO: Correct the type promotion semantics
@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    numel = indices.numel()
    grad = grad_output.view(numel, grad_output.size(-1))
    grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
    indices_rank1 = indices.view(numel)
    if scale_grad_by_freq:
        counts = indices.new_zeros((num_weights,))
        ones = indices.new_ones((numel,))
        counts = counts.index_put([indices_rank1], ones, accumulate=True)
        grad_weights_scale = counts[indices_rank1]
        grad = grad / grad_weights_scale.unsqueeze(1)
    skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
    skip_padding = skip_padding.expand_as(grad)
    zero_grad = torch.full_like(grad, 0)
    return grad_weight.index_put(
        [indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True
    )


def prod(x: List[int]):
    r = 1
    for i in x:
        r *= i
    return r


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
    num_splits = len(split_sizes)
    splits = []
    start_idx = 0
    for i in range(num_splits):
        length = split_sizes[i]
        splits.append(self.narrow(dim, start_idx, length))
        start_idx += length
    return splits


@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    chunks = (dim_size + split_size - 1) // split_size
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
    return torch.split(self, split_sizes, dim)
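
# Example (illustrative, not part of the original file): a split_size that does
# not divide the dimension evenly leaves a shorter final chunk, as in eager:
#
#   x = torch.arange(10)
#   [t.shape[0] for t in split(x, 4)]   # [4, 4, 2]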


# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
    if not self.is_floating_point() and not self.is_complex():
        beta = int(beta)
        alpha = int(alpha)
    out = alpha * torch.mm(mat1, mat2)
    if beta == 0:
        return out
    return beta * self + out


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm)
@pw_cast_for_opmath
def native_layer_norm(
    input: Tensor,
    normalized_shape: List[int],
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    M = prod(input_shape[:axis])  # type: ignore[arg-type]

    # Hmm... not sure how I get around this...
    # Basically, native_batch_norm doesn't support 0-entry tensors, while
    # native_layer_norm does (and is tested by OpInfos!)
    if M > 0:
        input_reshaped = input.view(1, M, -1)
    else:
        return (input, input.new_zeros((0,)), input.new_zeros((0,)))

    # Unlike Batch Normalization, which applies scalar scale and bias for each
    # entire channel/plane with the affine option, Layer Normalization applies
    # per-element scale and bias. E.g. For input {N, C, H, W}, weight for
    # batchnorm has shape {C} while weight for layernorm has shape {H, W} or {W}.
    out, mean, rstd = aten.native_batch_norm(
        input_reshaped,
        weight=None,
        bias=None,
        running_mean=None,
        running_var=None,
        training=True,
        momentum=0.0,
        eps=eps,
    )
    out = out.view(input_shape)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias

    stat_shape = list(input_shape[:axis])
    for _ in range(axis, input.dim()):
        stat_shape.append(1)
    mean = mean.view(stat_shape)
    rstd = rstd.view(stat_shape)
    return (out, mean, rstd)
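
# Example (illustrative, not part of the original file): the forward output
# should match F.layer_norm up to numerical error (mean/rstd are extra outputs
# kept around for the backward):
#
#   x = torch.randn(2, 3, 4)
#   w, b = torch.randn(4), torch.randn(4)
#   out, mean, rstd = native_layer_norm(x, [4], w, b, 1e-5)
#   torch.testing.assert_close(out, F.layer_norm(x, [4], w, b, 1e-5),
#                              rtol=1e-4, atol=1e-5)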


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
@pw_cast_for_opmath
def native_layer_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: List[int],
    mean: Tensor,
    rstd: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: List[int] = []
    outer_dim_indices: List[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape),
            input.new_zeros(input_shape[axis:]),
            input.new_zeros(input_shape[axis:]),
        )

    x_hat = (input - mean) * rstd
    if weight is not None:
        grad_x_hat = grad_out * weight
    else:
        grad_x_hat = grad_out
    a = grad_x_hat * N
    b = torch.sum(grad_x_hat, inner_dim_indices, True)
    c1 = torch.mul(grad_x_hat, x_hat)
    c2 = torch.sum(c1, inner_dim_indices, True)
    c3 = torch.mul(x_hat, c2)

    inner = a - b - c3

    if output_mask[0]:
        d_input: Optional[Tensor] = (rstd / N) * inner
    else:
        d_input = None

    if output_mask[1] and weight is not None:
        if len(outer_dim_indices) > 0:
            d_weight: Optional[Tensor] = torch.sum(
                grad_out * x_hat, outer_dim_indices, False
            )
        else:
            d_weight = grad_out * x_hat
    else:
        d_weight = None

    if output_mask[2] and bias is not None:
        if len(outer_dim_indices) > 0:
            d_bias: Optional[Tensor] = torch.sum(grad_out, outer_dim_indices, False)
        else:
            d_bias = grad_out
    else:
        d_bias = None
    return (d_input, d_weight, d_bias)


# TODO: Correct the type promotion semantics
@register_decomposition(aten.native_batch_norm)
@pw_cast_for_opmath
def native_batch_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
    reduction_dims = [0] + list(range(2, input.dim()))
    if training:
        # save_mean = torch.sum(input / (input.shape[0] * input.shape[2]), dim=reduction_dims)
        biased_var, save_mean = torch.var_mean(
            input, dim=reduction_dims, unbiased=False
        )
        save_invstd = 1 / (torch.sqrt(biased_var + eps))

        if running_mean is not None:
            running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
        if running_var is not None:
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            unbiased_var = biased_var * (n / (n - 1))
            running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
        mean = save_mean
        invstd = save_invstd
    else:
        assert running_mean is not None and running_var is not None
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type == "cuda":
            save_mean = running_mean
            save_invstd = invstd
        else:
            save_mean = input.new_zeros((0,))
            save_invstd = input.new_zeros((0,))

    if weight is None:
        weight = input.new_ones(())

    if bias is None:
        bias = input.new_zeros(())

    mean = _unsqueeze_to_dim(mean, input.dim() - 1)
    invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
    weight = _unsqueeze_to_dim(weight, input.dim() - 1)
    bias = _unsqueeze_to_dim(bias, input.dim() - 1)
    output = ((input - mean) * invstd) * weight + bias
    return output, save_mean, save_invstd
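
# Example (illustrative, not part of the original file): inference-mode batch
# norm using running statistics should match F.batch_norm up to numerical
# error:
#
#   x = torch.randn(2, 3, 4, 4)
#   rm, rv = torch.zeros(3), torch.ones(3)
#   out, _, _ = native_batch_norm(x, None, None, rm, rv, False, 0.1, 1e-5)
#   torch.testing.assert_close(out, F.batch_norm(x, rm, rv, training=False, eps=1e-5),
#                              rtol=1e-4, atol=1e-5)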


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    return torch.clamp(self, min=min)


@register_decomposition(aten.clamp_max)
def clamp_max(self: Tensor, max: float):
    return torch.clamp(self, max=max)


@register_decomposition(aten._fused_dropout)
@pw_cast_for_opmath
def _fused_dropout_decomposition(input, p, generator=None):
    mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
    res = mask.type_as(input) * input * (1.0 / p)
    return (res, mask)


# TODO: these logical decomps are buggy for complex inputs
@register_decomposition(aten.logical_xor)
def logical_xor(self: Tensor, other: Tensor) -> Tensor:
    return self.to(dtype=torch.bool) ^ other.to(dtype=torch.bool)


@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
    return ~self.to(dtype=torch.bool)


@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
    return aten.where(aten.isnan(self),
                      self,
                      aten.where(self == aten.new_zeros(self, ()),
                                 aten.new_zeros(self, ()),
                                 self * aten.log(other)))


@register_decomposition(aten.var.correction)
@reduction_complex_to_real
def var_correction(
    x: Tensor,
    dims: Optional[List[int]],
    correction: Optional[int] = None,
    keepdim: bool = False,
):
    if dims is None:
        dims = []

    if x.is_complex():
        # For complex, calculate variance of real and imaginary components
        # separately then add to get overall variance.
        real_in = x.real
        var_real = torch.var(real_in, dims, correction=correction, keepdim=keepdim)
        imag_in = x.imag
        var_imag = torch.var(imag_in, dims, correction=correction, keepdim=keepdim)
        return var_real + var_imag

    if correction is None:
        correction = 0

    if len(dims) == 0:
        n = prod(x.shape)  # type: ignore[arg-type]
    else:
        n = 1
        for dim in dims:
            n *= x.shape[dim]

    mean = torch.mean(x, dims, True)
    sub = x - mean
    sq = sub * sub
    sum = torch.sum(sq, dims, keepdim)

    if correction:
        n = n - correction

    return sum / n


@register_decomposition(aten.std.correction)
@reduction_complex_to_real
def std_decomposition(
    x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
):
    return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))


# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
    return x


@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
    input: Tensor,
    weight: Tensor,
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    exponential_average_factor: float,
    epsilon: float,
):
    a, b, c = aten.native_batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        exponential_average_factor,
        epsilon,
    )
    # Cudnn return running mean and variance when training is True
    if training:
        return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
    return (a, input.new_zeros((0,)), input.new_zeros((0,)), input.new_zeros((0,), dtype=torch.uint8))


@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
    input: Tensor,
    grad_output: Tensor,
    weight: Tensor,
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_var: Optional[Tensor],
    epsilon: float,
    reserveSpace: Tensor,
):
    return aten.native_batch_norm_backward(
        grad_output,
        input,
        weight,
        running_mean,
        running_var,
        save_mean,
        save_var,
        True,
        epsilon,
        [True, True, True],
    )


@register_decomposition(aten.rot90.default)
def rot90(self: Tensor, k: int = 1, dims: List[int] = [0, 1]) -> Tensor:  # noqa: B006
    total_dims = self.dim()
    total_rot_dims = len(dims)
    assert total_rot_dims == 2, f"expected total rotation dims == 2, but got dims = {total_rot_dims}"
    assert total_dims >= 2, f"expected total dims >= 2, but got total dims = {total_dims}"
    assert dims[0] != dims[1] and abs(dims[0] - dims[1]) != total_dims,\
        f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
    assert dims[0] < total_dims and dims[0] >= -total_dims, f"Rotation dim0 out of range, dim0 = {dims[0]}"
    assert dims[1] < total_dims and dims[1] >= -total_dims, f"Rotation dim1 out of range, dim1 = {dims[1]}"
    k = k % 4
    if k == 1:
        return self.flip(dims[1]).transpose(dims[0], dims[1])
    elif k == 2:
        return self.flip(dims)
    elif k == 3:
        return self.flip(dims[0]).transpose(dims[0], dims[1])
    else:
        return self.clone(memory_format=torch.contiguous_format)


@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
    dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]

    if self.dim() <= 1:
        return self

    if dim0 == dim1:
        return self
    perm = list(range(self.dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return torch.permute(self, perm)


@register_decomposition(aten.t.default)
def t(self: Tensor) -> Tensor:
    return self.transpose(0, 0 if self.dim() < 2 else 1)


def check_stack_inputs(tensors: List[Tensor]):
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        assert tensors[i].shape == entry_shape, (f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0"
                                                 f"and {tensors[i].shape} at entry {i}")


def get_stack_inputs(tensors: List[Tensor], dim: int):
    check_stack_inputs(tensors)
    return [t.unsqueeze(dim) for t in tensors]


@register_decomposition(aten.stack.default)
def stack(tensors: List[Tensor], dim: int = 0) -> Tensor:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    wrapped_dim = utils.canonicalize_dim(tensors[0].dim() + 1, dim)
    if wrapped_dim < tensors[0].dim() and not tensors[0].is_sparse:
        check_stack_inputs(tensors)
        result_sizes = list(tensors[0].shape)
        result_sizes.insert(wrapped_dim, len(tensors))
        out = torch.cat(tensors, wrapped_dim)
        return out.view(result_sizes)
    else:
        return torch.cat(get_stack_inputs(tensors, wrapped_dim), dim)


def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
    ndim = self.dim()
    wrapped_dims = utils.canonicalize_dims(ndim, dims)
    assert isinstance(wrapped_dims, tuple)
    for idx in range(ndim - 1, -1, -1):
        if idx in wrapped_dims:
            self = self.squeeze(idx)
    return self


@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(self, dim, keepdim=True)
    maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
    maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
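
# Example (illustrative, not part of the original file): the max-subtraction
# trick keeps this stable and matching torch.logsumexp for large inputs:
#
#   x = torch.tensor([1000.0, 1000.5, 999.0])
#   torch.testing.assert_close(logsumexp(x, [0]), torch.logsumexp(x, 0))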


@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
    return torch.sum(torch.diag(self))


# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward.default)
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
    min = torch.minimum(self.new_zeros(()), self)
    z = torch.exp(-torch.abs(self))
    if self.is_cuda:
        buffer = self.new_zeros((0,))
    else:
        buffer = z
    return min - torch.log1p(z), buffer