Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[functorch] moved decompositions to their own file
functorch/.gitignore (vendored): 1 line changed
@@ -13,3 +13,4 @@ docs/build
 docs/src
 docs/source/generated
 .DS_Store
+op_analysis/*.txt
functorch/functorch/_src/decompositions.py (new file): 88 lines added
@@ -0,0 +1,88 @@
import torch
from torch import Tensor
from enum import Enum

aten = torch.ops.aten

decomposition_table = {}

def register_decomposition(aten_op):
    def decomposition_decorator(f):
        decomposition_table[aten_op] = f
        return f
    return decomposition_decorator

class Reduction(Enum):
    NONE = 0
    MEAN = 1
    SUM = 2

@register_decomposition(aten.tanh_backward)
def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
    return out_grad * (1 - y * y)

@register_decomposition(aten.sigmoid_backward)
def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
    return out_grad * (y * (1 - y))

@register_decomposition(aten.softplus_backward)
# The out argument seems to always be ignored?
def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
    z = (x * beta).exp()
    return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))

@register_decomposition(aten.elu_backward)
def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
    negcoef = alpha * scale
    poscoef = scale
    negiptcoef = input_scale
    if is_result:
        return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
    else:
        return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)

@register_decomposition(aten.hardsigmoid_backward)
def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
    return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))

@register_decomposition(aten.hardtanh_backward)
def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
    return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)

@register_decomposition(aten.hardshrink_backward)
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)

@register_decomposition(aten.threshold_backward)
def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
    return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)

@register_decomposition(aten.leaky_relu_backward)
def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
    return aten.where(self > 0, grad_output, grad_output * negative_slope)

@register_decomposition(aten.mse_loss_backward)
def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
    norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
    return norm * (input - target) * grad_output

@register_decomposition(aten.huber_loss_backward)
def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
    norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
    x = self - target
    return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))

# @register_decomposition(aten._fused_dropout)
# def _fused_dropout_decomposition(input, p, generator=None):
#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
#     res = mask.type_as(input) * input * (1./p)
#     return [res, mask]

# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
@register_decomposition(aten.detach)
def detach_decomposition(x: Tensor):
    return x

@register_decomposition(aten._s_where)
def _s_where_canonicalization(a, b, c):
    return aten.where(a, b, c)
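Not part of the commit: a minimal sketch of how an entry in decomposition_table can be checked by hand. It assumes the new module is importable as functorch._src.decompositions and compares the registered tanh_backward decomposition against the analytic derivative of tanh.

import torch
from functorch._src.decompositions import decomposition_table

aten = torch.ops.aten

y = torch.tanh(torch.randn(8))   # y = tanh(x)
out_grad = torch.randn(8)        # incoming gradient

decomp = decomposition_table[aten.tanh_backward]
# d tanh(x)/dx = 1 - tanh(x)^2, so the decomposition should match out_grad * (1 - y * y)
assert torch.allclose(decomp(out_grad, y), out_grad * (1 - y * y))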
@@ -15,94 +15,11 @@ import torch.fx as fx
 import torch.fx._pytree as fx_pytree
 from torch import Tensor
 from .nnc_compile import nnc_compile
+from .decompositions import decomposition_table
-from enum import Enum
 import warnings
 from contextlib import contextmanager

(The roughly 80 lines removed here are the decomposition helpers now defined in decompositions.py above: the aten alias, decomposition_table, register_decomposition, the Reduction enum, and the same backward/detach/_s_where decompositions, verbatim.)

 USE_DECOMPOSE = False
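Besides the imports, the only survivor of this hunk is the module-level USE_DECOMPOSE flag; the file (presumably functorch/functorch/_src/python_key.py, given the imports in the next hunk) also pulls in contextmanager, and the package exports a pythonkey_decompose helper. The commit does not show how these fit together, so the following is purely an illustrative guess at such a toggle, not the commit's code:

from contextlib import contextmanager

USE_DECOMPOSE = False  # module-level default, as in the hunk above

@contextmanager
def pythonkey_decompose():
    # Hypothetical sketch: temporarily enable decompositions while tracing.
    global USE_DECOMPOSE
    saved = USE_DECOMPOSE
    USE_DECOMPOSE = True
    try:
        yield
    finally:
        USE_DECOMPOSE = saved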
@@ -1,6 +1,7 @@
 from .._src.operator_authoring import pointwise_operator
 from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
-from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
+from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose
+from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.nnc_compile import nnc_compile, get_ops
 from .._src.aot_autograd import (
     compiled_function,
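Not part of the diff: a hedged sketch of how the re-exported register_decomposition and decomposition_table could be used by downstream code after this change. The functorch.compile import path is an assumption about which __init__.py this hunk belongs to, and the silu_backward decomposition is purely illustrative.

import torch
from functorch.compile import register_decomposition, decomposition_table  # assumed public path

aten = torch.ops.aten

# Illustrative extra decomposition, registered through the same decorator the commit moves.
@register_decomposition(aten.silu_backward)
def silu_backward_decomposition(grad_output: torch.Tensor, x: torch.Tensor):
    sig = torch.sigmoid(x)
    # d silu(x)/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    return grad_output * sig * (1 + x * (1 - sig))

assert aten.silu_backward in decomposition_table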
Deleted file (11 lines):

@@ -1,11 +0,0 @@
_s_where
elu_backward
hardshrink_backward
hardsigmoid_backward
hardtanh_backward
huber_loss_backward
leaky_relu_backward
mse_loss_backward
sigmoid_backward
softplus_backward
tanh_backward
Deleted file (364 lines):

@@ -1,364 +0,0 @@
__and__
__or__
_adaptive_avg_pool2d
_adaptive_avg_pool2d_backward
_adaptive_avg_pool3d
_adaptive_avg_pool3d_backward
_cdist_backward
_cdist_forward
_conj
_conj_physical
_det_lu_based_helper
_euclidean_dist
_fft_c2c
_fft_c2r
_fft_r2c
_histogramdd_bin_edges
_histogramdd_from_bin_cts
_histogramdd_from_bin_tensors
_local_scalar_dense
_log_softmax
_log_softmax_backward_data
_lu_with_info
_reshape_alias
_slow_conv2d_backward
_slow_conv2d_forward
_softmax
_softmax_backward_data
_svd_helper
_to_copy
_unique2
_unsafe_view
abs
acos
acosh
adaptive_max_pool2d
adaptive_max_pool2d_backward
adaptive_max_pool3d
adaptive_max_pool3d_backward
add
add_
addbmm
addcdiv
addcmul
addcmul_
addmm
addmv
addr
alias
all
amax
amin
aminmax
angle
any
argmax
argmin
asin
asinh
atan
atan2
atanh
avg_pool2d
avg_pool2d_backward
avg_pool3d
avg_pool3d_backward
baddbmm
bernoulli_
bincount
bitwise_and
bitwise_and_
bitwise_left_shift
bitwise_not
bitwise_or
bitwise_or_
bitwise_right_shift
bitwise_xor
bmm
bucketize
cat
ceil
ceil_
celu
cholesky
cholesky_inverse
cholesky_solve
clamp
clamp_
clamp_min
clamp_min_
clone
complex
constant_pad_nd
copy_
copysign
cos
cosh
count_nonzero
cummax
cummin
cumprod
cumsum
deg2rad
detach
diag
diagonal
diagonal_scatter
digamma
dist
div
div_
dot
eig
embedding
empty_like
eq
erf
erf_
erfc
erfinv
exp
exp2
exp_
expand
expm1
fill_
flip
floor
floor_divide
fmax
fmin
fmod
frac
frexp
full_like
gather
ge
gelu
geqrf
grid_sampler_2d
grid_sampler_2d_backward
grid_sampler_3d
grid_sampler_3d_backward
gt
hardshrink
hardsigmoid
hardswish
hardswish_backward
hardtanh
histogram
huber_loss
hypot
i0
igamma
igammac
im2col
index
index_add_
index_copy_
index_fill_
index_put_
index_select
inverse
isin
isnan
isneginf
isposinf
kthvalue
le
leaky_relu
lerp
lerp_
lgamma
linalg_cholesky_ex
linalg_cross
linalg_eig
linalg_eigh
linalg_eigvalsh
linalg_householder_product
linalg_inv_ex
linalg_lstsq
linalg_matrix_exp
linalg_pinv
linalg_qr
linalg_slogdet
linalg_solve
linalg_vector_norm
log
log10
log1p
log2
log_sigmoid_backward
log_sigmoid_forward
logaddexp
logaddexp2
logcumsumexp
logdet
logical_and
logical_not
logical_or
logical_xor
logit
logsumexp
lt
lu_solve
lu_unpack
masked_fill_
masked_scatter_
masked_select
max
max_pool2d_with_indices
max_pool2d_with_indices_backward
maximum
mean
median
min
minimum
mish
mkldnn_convolution
mkldnn_convolution_backward
mm
mode
mse_loss
mul
mul_
mv
nan_to_num
nanmedian
nansum
native_batch_norm
native_batch_norm_backward
native_group_norm
native_layer_norm
native_layer_norm_backward
ne
neg
new_empty
nextafter
nll_loss2d_backward
nll_loss2d_forward
nll_loss_backward
nll_loss_forward
nonzero
norm
ones_like
ormqr
permute
polar
polygamma
pow
pow_
prod
put_
rad2deg
randint_like
randn_like
reciprocal
reciprocal_
reflection_pad1d
reflection_pad1d_backward
reflection_pad2d
reflection_pad2d_backward
reflection_pad3d
reflection_pad3d_backward
relu
remainder
renorm
repeat
repeat_interleave
replication_pad1d
replication_pad1d_backward
replication_pad2d
replication_pad2d_backward
replication_pad3d
replication_pad3d_backward
resize_as_
roll
rot90
round
rsqrt
rsub
scatter
scatter_
scatter_add
scatter_add_
searchsorted
select
select_backward
select_scatter
sgn
sigmoid
sign
signbit
sin
sinc
sinh
slice
slice_backward
slice_scatter
slow_conv_dilated2d
slow_conv_dilated2d_backward
slow_conv_transpose2d
slow_conv_transpose2d_backward
softplus
solve
sort
special_entr
special_erfcx
special_i0e
special_i1
special_i1e
special_ndtri
special_xlog1py
special_zeta
split
split_with_sizes
sqrt
sqrt_
squeeze
squeeze_
stack
std
std_mean
sub
sub_
sum
symeig
t
take
tan
tanh
threshold
threshold_backward
to_sparse
topk
trace
transpose
triangular_solve
tril
tril_
triu
triu_
trunc
unbind
unfold
unique_consecutive
unique_dim
unsqueeze
unsqueeze_
upsample_bicubic2d
upsample_bilinear2d
upsample_linear1d
upsample_nearest1d
upsample_nearest2d
upsample_nearest3d
upsample_trilinear3d
var
var_mean
vdot
view
view_as_real
where
xlogy
zero_
zeros_like