[functorch] Added softshrink_backward, prelu_backward, and col2im_backward decompositions

Author: Horace He
Date: 2022-02-17 02:35:01 +00:00
Committed by: Jon Janzen
Parent: ae24d77e36
Commit: 5bbb1b039d
3 changed files with 35 additions and 11 deletions
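
Not part of the commit, but for context: the new softshrink_backward decomposition (first file below) passes grad_output through wherever the input falls outside [-lambd, lambd] and zeroes it inside that band. A minimal sketch of how that formula could be sanity-checked against autograd; the shapes and the lambd value are arbitrary illustrations:

import torch

lambd = 0.5
x = torch.randn(8, requires_grad=True)
g = torch.randn(8)
torch.nn.functional.softshrink(x, lambd).backward(g)

# Same rule as the decomposition: zero inside [-lambd, lambd], grad_output outside.
manual = torch.where((x.detach() >= -lambd) & (x.detach() <= lambd), g.new_zeros(()), g)
torch.testing.assert_close(manual, x.grad)
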

View File

@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from typing import Optional, List
from typing import Optional, List, Tuple
from enum import Enum
aten = torch.ops.aten
@@ -116,7 +116,20 @@ def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
    sigmoid = 1 / (1 + aten.exp(aten.neg(self)))
    return grad_output * sigmoid * (1 + self * (1 - sigmoid))
# whyyyy does log_sigmoid do 2 different things for CPU and CUDA >:(
@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_output, ()), grad_output)
@register_decomposition(aten.prelu_backward)
def prelu_backward(grad_output: Tensor, self: Tensor, weight: Tensor) -> Tuple[Tensor, Tensor]:
    spatial_dims = list(range(2, grad_output.dim()))
    for _ in range(len(spatial_dims)):
        weight = weight.unsqueeze(-1)
    input_grad = aten.where(self > 0, grad_output, weight * grad_output)
    weight_grad_collector = aten.where(self > 0, aten.new_zeros(grad_output, ()), self * grad_output)
    return (input_grad, aten.sum(weight_grad_collector, [0] + spatial_dims))
@register_decomposition(aten.log_sigmoid_backward)
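
As a hedged illustration (not in the commit): prelu_backward above broadcasts the per-channel weight over the spatial dimensions and reduces the weight gradient over the batch and spatial dims. The same formula can be checked against autograd on a 4-D input; the shapes here are arbitrary:

import torch

x = torch.randn(2, 3, 4, 5, requires_grad=True)
w = torch.randn(3, requires_grad=True)
g = torch.randn(2, 3, 4, 5)
torch.nn.functional.prelu(x, w).backward(g)

# Mirror the decomposition: weight unsqueezed once per spatial dim, grads
# selected by the sign of the input, weight grad summed over dims 0, 2, 3.
w_b = w.detach().reshape(3, 1, 1)
input_grad = torch.where(x.detach() > 0, g, w_b * g)
weight_grad = torch.where(x.detach() > 0, g.new_zeros(()), x.detach() * g).sum(dim=(0, 2, 3))
torch.testing.assert_close(input_grad, x.grad)
torch.testing.assert_close(weight_grad, w.grad)
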
@@ -124,11 +137,10 @@ def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> T
    in_negative = self < 0
    max_deriv = aten.where(in_negative, 1, 0)
    sign = aten.where(in_negative, 1, -1)
    if grad_output.is_cuda:  # buffer is not used on CUDA
        z = aten.exp(-aten.abs(self))
        return grad_output * (max_deriv - sign * (z / (1 + z)))
    else:
        return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
    z = aten.exp(-aten.abs(self))
    return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but disabled for convenience sake
    # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
@register_decomposition(aten.mse_loss_backward)
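
A hedged aside on the log_sigmoid_backward change above (not in the commit): the buffer-free formula that now handles both devices matches the analytic gradient of log_sigmoid, d/dx log_sigmoid(x) = 1 - sigmoid(x), so dropping the CPU buffer branch should not change results. A small autograd check of that claim, with arbitrary sizes:

import torch

x = torch.randn(32, requires_grad=True)
g = torch.randn(32)
torch.nn.functional.logsigmoid(x).backward(g)

# Recompute the gradient the way the decomposition does, without the buffer.
xd = x.detach()
in_negative = xd < 0
max_deriv = torch.where(in_negative, torch.ones_like(xd), torch.zeros_like(xd))
sign = torch.where(in_negative, torch.ones_like(xd), -torch.ones_like(xd))
z = torch.exp(-xd.abs())
manual = g * (max_deriv - sign * (z / (1 + z)))
torch.testing.assert_close(manual, x.grad)
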
@@ -196,6 +208,14 @@ def im2col_backward(
    return aten.col2im(grad_output, input_size, kernel_size, dilation, padding, stride)
@register_decomposition(aten.col2im_backward)
def col2im_backward(
    grad_output: Tensor, kernel_size: List[int],
    dilation: List[int], padding: List[int], stride: List[int]
) -> Tensor:
    return aten.im2col(grad_output, kernel_size, dilation, padding, stride)
@register_decomposition(aten.native_dropout_backward)
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
    return grad_output * (mask.type_as(grad_output) * scale)
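
One more hedged illustration (not in the commit): the col2im_backward decomposition works because col2im (fold) and im2col (unfold) are adjoint linear maps, so the gradient of each one is the other with the same kernel, dilation, padding, and stride. A small autograd check of that relationship, with arbitrary sizes:

import torch

# 3 channels, 3x3 patches over an 8x8 image, padding 1, stride 1 -> 64 patch positions.
cols = torch.randn(1, 3 * 3 * 3, 64, requires_grad=True)
img = torch.nn.functional.fold(cols, output_size=(8, 8), kernel_size=3, padding=1)
g = torch.randn_like(img)
img.backward(g)

# The gradient of fold (col2im) is unfold (im2col) applied to grad_output.
manual = torch.nn.functional.unfold(g, kernel_size=3, padding=1)
torch.testing.assert_close(manual, cols.grad)
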
@@ -244,6 +264,7 @@ def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    return self + value * (tensor1 / tensor2)
# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
    if self.is_floating_point():

View File

@@ -2,6 +2,7 @@
#include <functorch/csrc/Constants.h>
#include <torch/library.h>
#include <ATen/ATen.h>
#include <ATen/WrapDimUtils.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchedTensorImpl.h>

View File

@@ -153,11 +153,12 @@ def gen_data(special_op_lists, analysis_name):
            if op['name'] in annotated_ops:
                categorization['core'] += 1
                op['meta'] = 'core ' + annotated_ops[op['name']]
            else:
                categorization['core'] += 1
                op['meta'] = 'core unknown'
                continue
            categorization['core'] += 1
            op['meta'] = 'core unknown'
        return categorization
    annotate_ops(ops, is_unique=False)
    with open(f"{analysis_name}", 'w') as f:
        for op in ops:
            info = [
@@ -178,9 +179,10 @@ def full_name_check(lst):
gen_data([full_name_check(get_ops_for_key('FuncTorchBatched'))], 'vmap')
if False:
if True:
    with open('run_ops.txt', 'r') as f:
        opinfo_ops = [i.strip() for i in f.readlines()]
    with open('run_decompositions.txt', 'r') as f:
        decomposed_ops = [i.strip() for i in f.readlines()]
    gen_data([name_check(opinfo_ops), name_check(decomposed_ops)], 'decompositions')