mirror of https://github.com/pytorch/pytorch.git
[functorch] Added softshrink_backward, prelu_backward, and col2im_backward decompositions
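As a quick sanity check for what this commit adds, a decomposition registered this way can be compared against the eager ATen kernel it stands in for. The snippet below is an illustrative sketch, not code from this commit; `softshrink_backward_decomp` is a hypothetical helper that restates the same formula the registered decomposition below uses, and it assumes `torch.ops.aten.softshrink_backward` can be called directly.

```python
# Illustrative check (not part of this commit): the decomposition's formula
# should agree with the eager aten::softshrink_backward kernel.
import torch

aten = torch.ops.aten

def softshrink_backward_decomp(grad_output, self, lambd):
    # Same formula as the registered decomposition below: zero gradient
    # inside [-lambd, lambd], pass-through outside.
    return aten.where((self >= -lambd) & (self <= lambd),
                      grad_output.new_zeros(()), grad_output)

x = torch.randn(4, 5)
go = torch.randn(4, 5)
print(torch.allclose(aten.softshrink_backward(go, x, 0.5),
                     softshrink_backward_decomp(go, x, 0.5)))  # True
```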
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from typing import Optional, List
from typing import Optional, List, Tuple
from enum import Enum

aten = torch.ops.aten
@@ -116,7 +116,20 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + aten.exp(aten.neg(self)))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))

# whyyyy does log_sigmoid do 2 different things for CPU and CUDA >:(

@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_output, ()), grad_output)


@register_decomposition(aten.prelu_backward)
def prelu_backward(grad_output: Tensor, self: Tensor, weight: Tensor) -> Tuple[Tensor, Tensor]:
    spatial_dims = list(range(2, grad_output.dim()))
    for _ in range(len(spatial_dims)):
        weight = weight.unsqueeze(-1)
    input_grad = aten.where(self > 0, grad_output, weight * grad_output)
    weight_grad_collector = aten.where(self > 0, aten.new_zeros(grad_output, ()), self * grad_output)
    return (input_grad, aten.sum(weight_grad_collector, [0] + spatial_dims))


@register_decomposition(aten.log_sigmoid_backward)
@@ -124,11 +137,10 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
    in_negative = self < 0
    max_deriv = aten.where(in_negative, 1, 0)
    sign = aten.where(in_negative, 1, -1)
    if grad_output.is_cuda: # buffer is not used on CUDA
        z = aten.exp(-aten.abs(self))
        return grad_output * (max_deriv - sign * (z / (1 + z)))
    else:
        return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
    z = aten.exp(-aten.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output


@register_decomposition(aten.mse_loss_backward)
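The log_sigmoid_backward hunk above drops the CPU/CUDA branch and uses the buffer-free formula unconditionally. As a hedged aside (not part of the commit, using only plain torch ops), that formula can be checked against autograd's gradient for F.logsigmoid:

```python
# Hedged numeric check (not from this commit): the buffer-free formula used
# above matches autograd's gradient for F.logsigmoid (grad_output == 1 here).
import torch
import torch.nn.functional as F

x = torch.randn(10, requires_grad=True)
F.logsigmoid(x).sum().backward()

xd = x.detach()
in_negative = xd < 0
max_deriv = torch.where(in_negative, torch.ones_like(xd), torch.zeros_like(xd))
sign = torch.where(in_negative, torch.ones_like(xd), -torch.ones_like(xd))
z = torch.exp(-xd.abs())
manual = max_deriv - sign * (z / (1 + z))
print(torch.allclose(x.grad, manual))  # True
```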
@@ -196,6 +208,14 @@ def im2col_backward(
    return aten.col2im(grad_output, input_size, kernel_size, dilation, padding, stride)


@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor, kernel_size: List[int],
    dilation: List[int], padding: List[int], stride: List[int]
) -> Tensor:
    return aten.im2col(grad_output, kernel_size, dilation, padding, stride)


@register_decomposition(aten.native_dropout_backward)
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    return grad_output * (mask.type_as(grad_output) * scale)
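For intuition on the col2im_backward decomposition above: col2im ("fold") and im2col ("unfold") are adjoint data movements, so the backward of col2im is an im2col with the same kernel_size, dilation, padding, and stride. A small, illustrative shape sketch (not from this commit) using the public aten ops:

```python
# Hedged sketch (not from this commit): im2col/col2im round-trip shapes.
import torch

aten = torch.ops.aten

img = torch.randn(1, 3, 8, 8)
cols = aten.im2col(img, [3, 3], [1, 1], [0, 0], [1, 1])            # unfold
back = aten.col2im(cols, [8, 8], [3, 3], [1, 1], [0, 0], [1, 1])   # fold
print(cols.shape)  # torch.Size([1, 27, 36]): 3*3*3 values per 6x6 patch grid
print(back.shape)  # torch.Size([1, 3, 8, 8]): back to the image layout
```

Note that `back` only matches `img` where patches do not overlap (overlapping contributions are summed), which is why the sketch checks shapes rather than values.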
@@ -244,6 +264,7 @@ def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    return self + value * (tensor1 / tensor2)


# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    if self.is_floating_point():
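Before moving to the other files in this commit, a quick shape check for the prelu_backward decomposition added above: the weight gradient is collected elementwise and then summed over the batch and spatial dims ([0] + spatial_dims), so it must end up with one value per channel. An illustrative sketch (not from this commit) that uses autograd as the reference:

```python
# Illustrative shape check (not part of this commit): the reference grads
# from autograd have the shapes the prelu_backward decomposition must match.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4, requires_grad=True)   # (N, C, H, W)
w = torch.rand(3, requires_grad=True)             # one weight per channel
F.prelu(x, w).sum().backward()

print(x.grad.shape)  # torch.Size([2, 3, 4, 4]) -> input_grad
print(w.grad.shape)  # torch.Size([3])          -> sum over dims [0, 2, 3]
```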
@@ -2,6 +2,7 @@
#include <functorch/csrc/Constants.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchedTensorImpl.h>

@@ -153,11 +153,12 @@ def gen_data(special_op_lists, analysis_name):
            if op['name'] in annotated_ops:
                categorization['core'] += 1
                op['meta'] = 'core ' + annotated_ops[op['name']]
            else:
                categorization['core'] += 1
                op['meta'] = 'core unknown'
                continue
            categorization['core'] += 1
            op['meta'] = 'core unknown'
        return categorization

    annotate_ops(ops, is_unique=False)
    with open(f"{analysis_name}", 'w') as f:
        for op in ops:
            info = [
@@ -178,9 +179,10 @@ def full_name_check(lst):
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')


if False:
if True:
    with open('run_ops.txt', 'r') as f:
        opinfo_ops = [i.strip() for i in f.readlines()]
    with open('run_decompositions.txt', 'r') as f:
        decomposed_ops = [i.strip() for i in f.readlines()]
    gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')