From d1f8127043633ca4eabc13c29d6d80f205977b42 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Tue, 30 Nov 2021 19:48:04 +0000
Subject: [PATCH] [functorch] moved decompositions to their own file

---
 functorch/.gitignore                          |   1 +
 functorch/functorch/_src/decompositions.py    |  88 +++++
 functorch/functorch/_src/python_key.py        |  85 +----
 functorch/functorch/compile/__init__.py       |   3 +-
 functorch/op_analysis/run_decompositions.txt  |  11 -
 functorch/op_analysis/run_ops.txt             | 364 ------------------
 6 files changed, 92 insertions(+), 460 deletions(-)
 create mode 100644 functorch/functorch/_src/decompositions.py
 delete mode 100644 functorch/op_analysis/run_decompositions.txt
 delete mode 100644 functorch/op_analysis/run_ops.txt

diff --git a/functorch/.gitignore b/functorch/.gitignore
index c8d3aa177317..e0872b7cf0c3 100644
--- a/functorch/.gitignore
+++ b/functorch/.gitignore
@@ -13,3 +13,4 @@ docs/build
 docs/src
 docs/source/generated
 .DS_Store
+op_analysis/*.txt
diff --git a/functorch/functorch/_src/decompositions.py b/functorch/functorch/_src/decompositions.py
new file mode 100644
index 000000000000..2bde73e22ae2
--- /dev/null
+++ b/functorch/functorch/_src/decompositions.py
@@ -0,0 +1,88 @@
+import torch
+from torch import Tensor
+from enum import Enum
+
+aten = torch.ops.aten
+
+decomposition_table = {}
+
+def register_decomposition(aten_op):
+    def decomposition_decorator(f):
+        decomposition_table[aten_op] = f
+        return f
+    return decomposition_decorator
+
+class Reduction(Enum):
+    NONE = 0
+    MEAN = 1
+    SUM = 2
+
+@register_decomposition(aten.tanh_backward)
+def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
+    return out_grad * (1 - y * y)
+
+@register_decomposition(aten.sigmoid_backward)
+def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
+    return out_grad * (y * (1 - y))
+
+@register_decomposition(aten.softplus_backward)
+# The out argument seems to always be ignored?
+def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
+    z = (x * beta).exp()
+    return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
+
+@register_decomposition(aten.elu_backward)
+def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
+    negcoef = alpha * scale
+    poscoef = scale
+    negiptcoef = input_scale
+    if is_result:
+        return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
+    else:
+        return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)
+
+@register_decomposition(aten.hardsigmoid_backward)
+def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
+    return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))
+
+@register_decomposition(aten.hardtanh_backward)
+def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
+    return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)
+
+@register_decomposition(aten.hardshrink_backward)
+def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
+    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
+
+@register_decomposition(aten.threshold_backward)
+def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
+    return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
+
+@register_decomposition(aten.leaky_relu_backward)
+def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
+    return aten.where(self > 0, grad_output, grad_output * negative_slope)
+
+@register_decomposition(aten.mse_loss_backward)
+def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
+    norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
+    return norm * (input - target) * grad_output
+
+@register_decomposition(aten.huber_loss_backward)
+def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
+    norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
+    x = self - target
+    return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
+
+# @register_decomposition(aten._fused_dropout)
+# def _fused_dropout_decomposition(input, p, generator=None):
+#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
+#     res = mask.type_as(input) * input * (1./p)
+#     return [res, mask]
+
+# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
+@register_decomposition(aten.detach)
+def detach_decomposition(x: Tensor):
+    return x
+
+@register_decomposition(aten._s_where)
+def _s_where_canonicalization(a, b, c):
+    return aten.where(a, b, c)
\ No newline at end of file
diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py
index f3f52be53bea..6156d97f8cb2 100644
--- a/functorch/functorch/_src/python_key.py
+++ b/functorch/functorch/_src/python_key.py
@@ -15,94 +15,11 @@ import torch.fx as fx
 import torch.fx._pytree as fx_pytree
 from torch import Tensor
 from .nnc_compile import nnc_compile
+from .decompositions import decomposition_table
 from enum import Enum
 import warnings
 from contextlib import contextmanager
 
-aten = torch.ops.aten
-
-decomposition_table = {}
-
-def register_decomposition(aten_op):
-    def decomposition_decorator(f):
-        decomposition_table[aten_op] = f
-        return f
-    return decomposition_decorator
-
-class Reduction(Enum):
-    NONE = 0
-    MEAN = 1
-    SUM = 2
-
-@register_decomposition(aten.tanh_backward)
-def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
-    return out_grad * (1 - y * y)
-
-@register_decomposition(aten.sigmoid_backward)
-def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
-    return out_grad * (y * (1 - y))
-
-@register_decomposition(aten.softplus_backward)
-# The out argument seems to always be ignored?
-def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
-    z = (x * beta).exp()
-    return aten.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
-
-@register_decomposition(aten.elu_backward)
-def elu_backward_decomposition(grad_output: Tensor, alpha: float, scale: float, input_scale: float, is_result: bool, self_or_result: Tensor):
-    negcoef = alpha * scale
-    poscoef = scale
-    negiptcoef = input_scale
-    if is_result:
-        return aten.where(self_or_result <= 0, grad_output * negiptcoef * (self_or_result + negcoef), self_or_result * poscoef)
-    else:
-        return aten.where(self_or_result <= 0, grad_output * negiptcoef * negcoef * aten.exp(self_or_result * negiptcoef), grad_output * poscoef)
-
-@register_decomposition(aten.hardsigmoid_backward)
-def hardsigmoid_backward_decomposition(grad_output: Tensor, self: Tensor):
-    return aten.where((self > -3.0) & (self < 3.0), grad_output * (1.0/6.0), aten.new_zeros(grad_output, ()))
-
-@register_decomposition(aten.hardtanh_backward)
-def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val: float, max_val: float):
-    return aten.where((self <= min_val) | (self >= max_val), aten.new_zeros(grad_output, ()), grad_output)
-
-@register_decomposition(aten.hardshrink_backward)
-def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
-    return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
-
-@register_decomposition(aten.threshold_backward)
-def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
-    return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
-
-@register_decomposition(aten.leaky_relu_backward)
-def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
-    return aten.where(self > 0, grad_output, grad_output * negative_slope)
-
-@register_decomposition(aten.mse_loss_backward)
-def mse_loss_backward_decomposition(grad_output: Tensor, input: Tensor, target: Tensor, reduction: int):
-    norm = 2./input.numel() if reduction == Reduction.MEAN.value else 2.
-    return norm * (input - target) * grad_output
-
-@register_decomposition(aten.huber_loss_backward)
-def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float):
-    norm = 1./self.numel() if reduction == Reduction.MEAN.value else 1.
-    x = self - target
-    return aten.where(x < -delta, -norm * grad_output * delta, aten.where(x > delta, norm * grad_output * delta, norm * x * grad_output))
-
-# @register_decomposition(aten._fused_dropout)
-# def _fused_dropout_decomposition(input, p, generator=None):
-#     mask = aten.to(aten.rand_like(input) < p, dtype=torch.uint8)
-#     res = mask.type_as(input) * input * (1./p)
-#     return [res, mask]
-
-# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
-@register_decomposition(aten.detach)
-def detach_decomposition(x: Tensor):
-    return x
-
-@register_decomposition(aten._s_where)
-def _s_where_canonicalization(a, b, c):
-    return aten.where(a, b, c)
 
 USE_DECOMPOSE = False
 
diff --git a/functorch/functorch/compile/__init__.py b/functorch/functorch/compile/__init__.py
index 5745ed16a5e8..585d98b80d45 100644
--- a/functorch/functorch/compile/__init__.py
+++ b/functorch/functorch/compile/__init__.py
@@ -1,6 +1,7 @@
 from .._src.operator_authoring import pointwise_operator
 from .._src.memory_efficient_op_authoring import memory_efficient_operator_authoring, torchscript_nvfuser_compile
-from .._src.python_key import nnc_jit, make_nnc, register_decomposition, pythonkey_decompose, decomposition_table
+from .._src.python_key import nnc_jit, make_nnc, pythonkey_decompose
+from .._src.decompositions import register_decomposition, decomposition_table
 from .._src.nnc_compile import nnc_compile, get_ops
 from .._src.aot_autograd import (
     compiled_function,
diff --git a/functorch/op_analysis/run_decompositions.txt b/functorch/op_analysis/run_decompositions.txt
deleted file mode 100644
index e47a76641cb6..000000000000
--- a/functorch/op_analysis/run_decompositions.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-_s_where
-elu_backward
-hardshrink_backward
-hardsigmoid_backward
-hardtanh_backward
-huber_loss_backward
-leaky_relu_backward
-mse_loss_backward
-sigmoid_backward
-softplus_backward
-tanh_backward
diff --git a/functorch/op_analysis/run_ops.txt b/functorch/op_analysis/run_ops.txt
deleted file mode 100644
index 56b4a3b35168..000000000000
--- a/functorch/op_analysis/run_ops.txt
+++ /dev/null
@@ -1,364 +0,0 @@
-__and__
-__or__
-_adaptive_avg_pool2d
-_adaptive_avg_pool2d_backward
-_adaptive_avg_pool3d
-_adaptive_avg_pool3d_backward
-_cdist_backward
-_cdist_forward
-_conj
-_conj_physical
-_det_lu_based_helper
-_euclidean_dist
-_fft_c2c
-_fft_c2r
-_fft_r2c
-_histogramdd_bin_edges
-_histogramdd_from_bin_cts
-_histogramdd_from_bin_tensors
-_local_scalar_dense
-_log_softmax
-_log_softmax_backward_data
-_lu_with_info
-_reshape_alias
-_slow_conv2d_backward
-_slow_conv2d_forward
-_softmax
-_softmax_backward_data
-_svd_helper
-_to_copy
-_unique2
-_unsafe_view
-abs
-acos
-acosh
-adaptive_max_pool2d
-adaptive_max_pool2d_backward
-adaptive_max_pool3d
-adaptive_max_pool3d_backward
-add
-add_
-addbmm
-addcdiv
-addcmul
-addcmul_
-addmm
-addmv
-addr
-alias
-all
-amax
-amin
-aminmax
-angle
-any
-argmax
-argmin
-asin
-asinh
-atan
-atan2
-atanh
-avg_pool2d
-avg_pool2d_backward
-avg_pool3d
-avg_pool3d_backward
-baddbmm
-bernoulli_
-bincount
-bitwise_and
-bitwise_and_
-bitwise_left_shift
-bitwise_not
-bitwise_or
-bitwise_or_
-bitwise_right_shift
-bitwise_xor
-bmm
-bucketize
-cat
-ceil
-ceil_
-celu
-cholesky
-cholesky_inverse
-cholesky_solve
-clamp
-clamp_
-clamp_min
-clamp_min_
-clone
-complex
-constant_pad_nd
-copy_
-copysign
-cos
-cosh
-count_nonzero
-cummax
-cummin
-cumprod
-cumsum
-deg2rad
-detach
-diag
-diagonal
-diagonal_scatter
-digamma
-dist
-div
-div_
-dot
-eig
-embedding
-empty_like
-eq
-erf
-erf_
-erfc
-erfinv
-exp
-exp2
-exp_
-expand
-expm1
-fill_
-flip
-floor
-floor_divide
-fmax
-fmin
-fmod
-frac
-frexp
-full_like
-gather
-ge
-gelu
-geqrf
-grid_sampler_2d
-grid_sampler_2d_backward
-grid_sampler_3d
-grid_sampler_3d_backward
-gt
-hardshrink
-hardsigmoid
-hardswish
-hardswish_backward
-hardtanh
-histogram
-huber_loss
-hypot
-i0
-igamma
-igammac
-im2col
-index
-index_add_
-index_copy_
-index_fill_
-index_put_
-index_select
-inverse
-isin
-isnan
-isneginf
-isposinf
-kthvalue
-le
-leaky_relu
-lerp
-lerp_
-lgamma
-linalg_cholesky_ex
-linalg_cross
-linalg_eig
-linalg_eigh
-linalg_eigvalsh
-linalg_householder_product
-linalg_inv_ex
-linalg_lstsq
-linalg_matrix_exp
-linalg_pinv
-linalg_qr
-linalg_slogdet
-linalg_solve
-linalg_vector_norm
-log
-log10
-log1p
-log2
-log_sigmoid_backward
-log_sigmoid_forward
-logaddexp
-logaddexp2
-logcumsumexp
-logdet
-logical_and
-logical_not
-logical_or
-logical_xor
-logit
-logsumexp
-lt
-lu_solve
-lu_unpack
-masked_fill_
-masked_scatter_
-masked_select
-max
-max_pool2d_with_indices
-max_pool2d_with_indices_backward
-maximum
-mean
-median
-min
-minimum
-mish
-mkldnn_convolution
-mkldnn_convolution_backward
-mm
-mode
-mse_loss
-mul
-mul_
-mv
-nan_to_num
-nanmedian
-nansum
-native_batch_norm
-native_batch_norm_backward
-native_group_norm
-native_layer_norm
-native_layer_norm_backward
-ne
-neg
-new_empty
-nextafter
-nll_loss2d_backward
-nll_loss2d_forward
-nll_loss_backward
-nll_loss_forward
-nonzero
-norm
-ones_like
-ormqr
-permute
-polar
-polygamma
-pow
-pow_
-prod
-put_
-rad2deg
-randint_like
-randn_like
-reciprocal
-reciprocal_
-reflection_pad1d
-reflection_pad1d_backward
-reflection_pad2d
-reflection_pad2d_backward
-reflection_pad3d
-reflection_pad3d_backward
-relu
-remainder
-renorm
-repeat
-repeat_interleave
-replication_pad1d
-replication_pad1d_backward
-replication_pad2d
-replication_pad2d_backward
-replication_pad3d
-replication_pad3d_backward
-resize_as_
-roll
-rot90
-round
-rsqrt
-rsub
-scatter
-scatter_
-scatter_add
-scatter_add_
-searchsorted
-select
-select_backward
-select_scatter
-sgn
-sigmoid
-sign
-signbit
-sin
-sinc
-sinh
-slice
-slice_backward
-slice_scatter
-slow_conv_dilated2d
-slow_conv_dilated2d_backward
-slow_conv_transpose2d
-slow_conv_transpose2d_backward
-softplus
-solve
-sort
-special_entr
-special_erfcx
-special_i0e
-special_i1
-special_i1e
-special_ndtri
-special_xlog1py
-special_zeta
-split
-split_with_sizes
-sqrt
-sqrt_
-squeeze
-squeeze_
-stack
-std
-std_mean
-sub
-sub_
-sum
-symeig
-t
-take
-tan
-tanh
-threshold
-threshold_backward
-to_sparse
-topk
-trace
-transpose
-triangular_solve
-tril
-tril_
-triu
-triu_
-trunc
-unbind
-unfold
-unique_consecutive
-unique_dim
-unsqueeze
-unsqueeze_
-upsample_bicubic2d
-upsample_bilinear2d
-upsample_linear1d
-upsample_nearest1d
-upsample_nearest2d
-upsample_nearest3d
-upsample_trilinear3d
-var
-var_mean
-vdot
-view
-view_as_real
-where
-xlogy
-zero_
-zeros_like
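
Not part of the patch itself: a minimal usage sketch of the decomposition registry this change moves into functorch/functorch/_src/decompositions.py and re-exports from functorch.compile. It assumes a functorch build containing this commit; the aten.relu registration at the end is purely illustrative and is not one of the decompositions registered above.

import torch
from functorch.compile import decomposition_table, register_decomposition

aten = torch.ops.aten

# Look up the decomposition registered for tanh_backward and compare it
# against the ATen kernel on a random input.
tanh_bwd = decomposition_table[aten.tanh_backward]

x = torch.randn(8)
y = torch.tanh(x)
grad_out = torch.randn_like(y)

print(torch.allclose(aten.tanh_backward(grad_out, y), tanh_bwd(grad_out, y)))  # True

# New decompositions follow the same decorator pattern used in the patch;
# aten.relu is only an illustrative key here, not something the patch adds.
@register_decomposition(aten.relu)
def relu_decomposition(x):
    return aten.where(x > 0, x, torch.zeros_like(x))

assert decomposition_table[aten.relu] is relu_decomposition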