[functorch] moved decompositions to their own file

Author: Horace He
Date: 2021-11-30 19:48:04 +00:00
Committed by: Jon Janzen
Parent: 653e56b6b0
Commit: d1f8127043
6 changed files with 92 additions and 460 deletions

@@ -13,3 +13,4 @@ docs/build
docs/src
docs/source/generated
.DS_Store
op_analysis/*.txt

@@ -0,0 +1,88 @@
import torch
from torch import Tensor
from enum import Enum

aten = torch.ops.aten

decomposition_table = {}

def register_decomposition(aten_op):
    def decomposition_decorator(f):
        decomposition_table[aten_op] = f
        return f
    return decomposition_decorator

class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2

@register_decomposition(aten.tanh_backward)
def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y)

@register_decomposition(aten.sigmoid_backward)
def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y))

@register_decomposition(aten.softplus_backward)
# The out argument seems to always be ignored?
def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
    z = (x * beta).exp()
    return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))

@register_decomposition(aten.elu_backward)
def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
    else:
        return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)

@register_decomposition(aten.hardsigmoid_backward)
def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
    return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))

@register_decomposition(aten.hardtanh_backward)
def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
    return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)

@register_decomposition(aten.hardshrink_backward)
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)

@register_decomposition(aten.threshold_backward)
def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
    return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)

@register_decomposition(aten.leaky_relu_backward)
def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
    return aten.where(self > 0, grad_output, grad_output * negative_slope)

@register_decomposition(aten.mse_loss_backward)
def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
    norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
    return norm * (input - target) * grad_output

@register_decomposition(aten.huber_loss_backward)
def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
    norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
    x = self - target
    return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))

# @register_decomposition(aten._fused_dropout)
# def _fused_dropout_decomposition(input, p, generator=None):
#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
#     res = mask.type_as(input) * input * (1./p)
#     return [res, mask]

# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
@register_decomposition(aten.detach)
def detach_decomposition(x: Tensor):
    return x

@register_decomposition(aten._s_where)
def _s_where_canonicalization(a, b, c):
    return aten.where(a, b, c)
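
For reference (not part of this commit), a registered decomposition can be looked up in decomposition_table and checked against the eager ATen kernel it re-expresses. The sketch below assumes the new module is importable as functorch._src.decompositions, following the relative import added in the __init__ hunk further down; adjust the path if the package layout differs.

# Illustrative check, not part of the commit: compare a registered
# decomposition against the eager ATen kernel it replaces.
import torch
from functorch._src.decompositions import decomposition_table  # assumed import path

aten = torch.ops.aten

x = torch.randn(8)
y = torch.tanh(x)
grad_out = torch.randn(8)

eager = aten.tanh_backward(grad_out, y)                             # eager ATen kernel
decomposed = decomposition_table[aten.tanh_backward](grad_out, y)   # out_grad * (1 - y * y)
torch.testing.assert_close(eager, decomposed)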

@@ -15,94 +15,11 @@ import torch.fx as fx
import torch.fx._pytree as fx_pytree
from torch import Tensor
from .nnc_compile import nnc_compile
from .decompositions import decomposition_table
from enum import Enum
import warnings
from contextlib import contextmanager
aten = torch.ops.aten
USE_DECOMPOSE = False
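
python_key.py keeps a module-level USE_DECOMPOSE flag, and the exported pythonkey_decompose helper (see the __init__ hunk below) toggles decomposition during tracing. Its body falls outside this hunk, so the snippet below is only a minimal sketch of that flag-toggling pattern, not the actual implementation.

# Sketch only: a context manager that flips a module-level flag while tracing,
# in the spirit of pythonkey_decompose / USE_DECOMPOSE. The real implementation
# lives in python_key.py and is not shown in this hunk.
from contextlib import contextmanager

USE_DECOMPOSE = False

@contextmanager
def pythonkey_decompose():
    global USE_DECOMPOSE
    USE_DECOMPOSE = True
    try:
        yield
    finally:
        USE_DECOMPOSE = False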

@@ -1,6 +1,7 @@
from .._src.operator_authoring import pointwise_operator
from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose
from .._src.decompositions import register_decomposition, decomposition_table
from .._src.nnc_compile import nnc_compile, get_ops
from .._src.aot_autograd import (
    compiled_function,
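
With register_decomposition and decomposition_table now re-exported here from the new decompositions module, callers can register their own entries. A minimal sketch, reusing the private functorch._src.decompositions path from the import above; aten.silu_backward is a real ATen op, but this particular entry is illustrative and not part of the commit.

# Illustrative registration, not part of the commit.
# SiLU derivative: silu'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
import torch
from functorch._src.decompositions import register_decomposition, decomposition_table

aten = torch.ops.aten

@register_decomposition(aten.silu_backward)
def silu_backward_decomposition(grad_output: torch.Tensor, x: torch.Tensor):
    sig = torch.sigmoid(x)
    return grad_output * sig * (1 + x * (1 - sig))

assert aten.silu_backward in decomposition_table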

@@ -1,11 +0,0 @@
_s_where
elu_backward
hardshrink_backward
hardsigmoid_backward
hardtanh_backward
huber_loss_backward
leaky_relu_backward
mse_loss_backward
sigmoid_backward
softplus_backward
tanh_backward

@@ -1,364 +0,0 @@
__and__
__or__
_adaptive_avg_pool2d
_adaptive_avg_pool2d_backward
_adaptive_avg_pool3d
_adaptive_avg_pool3d_backward
_cdist_backward
_cdist_forward
_conj
_conj_physical
_det_lu_based_helper
_euclidean_dist
_fft_c2c
_fft_c2r
_fft_r2c
_histogramdd_bin_edges
_histogramdd_from_bin_cts
_histogramdd_from_bin_tensors
_local_scalar_dense
_log_softmax
_log_softmax_backward_data
_lu_with_info
_reshape_alias
_slow_conv2d_backward
_slow_conv2d_forward
_softmax
_softmax_backward_data
_svd_helper
_to_copy
_unique2
_unsafe_view
abs
acos
acosh
adaptive_max_pool2d
adaptive_max_pool2d_backward
adaptive_max_pool3d
adaptive_max_pool3d_backward
add
add_
addbmm
addcdiv
addcmul
addcmul_
addmm
addmv
addr
alias
all
amax
amin
aminmax
angle
any
argmax
argmin
asin
asinh
atan
atan2
atanh
avg_pool2d
avg_pool2d_backward
avg_pool3d
avg_pool3d_backward
baddbmm
bernoulli_
bincount
bitwise_and
bitwise_and_
bitwise_left_shift
bitwise_not
bitwise_or
bitwise_or_
bitwise_right_shift
bitwise_xor
bmm
bucketize
cat
ceil
ceil_
celu
cholesky
cholesky_inverse
cholesky_solve
clamp
clamp_
clamp_min
clamp_min_
clone
complex
constant_pad_nd
copy_
copysign
cos
cosh
count_nonzero
cummax
cummin
cumprod
cumsum
deg2rad
detach
diag
diagonal
diagonal_scatter
digamma
dist
div
div_
dot
eig
embedding
empty_like
eq
erf
erf_
erfc
erfinv
exp
exp2
exp_
expand
expm1
fill_
flip
floor
floor_divide
fmax
fmin
fmod
frac
frexp
full_like
gather
ge
gelu
geqrf
grid_sampler_2d
grid_sampler_2d_backward
grid_sampler_3d
grid_sampler_3d_backward
gt
hardshrink
hardsigmoid
hardswish
hardswish_backward
hardtanh
histogram
huber_loss
hypot
i0
igamma
igammac
im2col
index
index_add_
index_copy_
index_fill_
index_put_
index_select
inverse
isin
isnan
isneginf
isposinf
kthvalue
le
leaky_relu
lerp
lerp_
lgamma
linalg_cholesky_ex
linalg_cross
linalg_eig
linalg_eigh
linalg_eigvalsh
linalg_householder_product
linalg_inv_ex
linalg_lstsq
linalg_matrix_exp
linalg_pinv
linalg_qr
linalg_slogdet
linalg_solve
linalg_vector_norm
log
log10
log1p
log2
log_sigmoid_backward
log_sigmoid_forward
logaddexp
logaddexp2
logcumsumexp
logdet
logical_and
logical_not
logical_or
logical_xor
logit
logsumexp
lt
lu_solve
lu_unpack
masked_fill_
masked_scatter_
masked_select
max
max_pool2d_with_indices
max_pool2d_with_indices_backward
maximum
mean
median
min
minimum
mish
mkldnn_convolution
mkldnn_convolution_backward
mm
mode
mse_loss
mul
mul_
mv
nan_to_num
nanmedian
nansum
native_batch_norm
native_batch_norm_backward
native_group_norm
native_layer_norm
native_layer_norm_backward
ne
neg
new_empty
nextafter
nll_loss2d_backward
nll_loss2d_forward
nll_loss_backward
nll_loss_forward
nonzero
norm
ones_like
ormqr
permute
polar
polygamma
pow
pow_
prod
put_
rad2deg
randint_like
randn_like
reciprocal
reciprocal_
reflection_pad1d
reflection_pad1d_backward
reflection_pad2d
reflection_pad2d_backward
reflection_pad3d
reflection_pad3d_backward
relu
remainder
renorm
repeat
repeat_interleave
replication_pad1d
replication_pad1d_backward
replication_pad2d
replication_pad2d_backward
replication_pad3d
replication_pad3d_backward
resize_as_
roll
rot90
round
rsqrt
rsub
scatter
scatter_
scatter_add
scatter_add_
searchsorted
select
select_backward
select_scatter
sgn
sigmoid
sign
signbit
sin
sinc
sinh
slice
slice_backward
slice_scatter
slow_conv_dilated2d
slow_conv_dilated2d_backward
slow_conv_transpose2d
slow_conv_transpose2d_backward
softplus
solve
sort
special_entr
special_erfcx
special_i0e
special_i1
special_i1e
special_ndtri
special_xlog1py
special_zeta
split
split_with_sizes
sqrt
sqrt_
squeeze
squeeze_
stack
std
std_mean
sub
sub_
sum
symeig
t
take
tan
tanh
threshold
threshold_backward
to_sparse
topk
trace
transpose
triangular_solve
tril
tril_
triu
triu_
trunc
unbind
unfold
unique_consecutive
unique_dim
unsqueeze
unsqueeze_
upsample_bicubic2d
upsample_bilinear2d
upsample_linear1d
upsample_nearest1d
upsample_nearest2d
upsample_nearest3d
upsample_trilinear3d
var
var_mean
vdot
view
view_as_real
where
xlogy
zero_
zeros_like