[functorch] moved some stuff around

This commit is contained in:
Horace He
2021-11-30 19:36:39 +00:00
committed by Jon Janzen
parent f3946010cd
commit 653e56b6b0

View File

@ -42,11 +42,6 @@ def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
return out_grad * (y * (1 - y))
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
@register_decomposition(aten.detach)
def detach_decomposition(x: Tensor):
return x
@register_decomposition(aten.softplus_backward)
# The out argument seems to always be ignored?
def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
@ -75,9 +70,9 @@ def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val:
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
# @register_decomposition(aten.threshold_backward)
# def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
# return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
@register_decomposition(aten.threshold_backward)
def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
@register_decomposition(aten.leaky_relu_backward)
def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
@ -100,6 +95,11 @@ def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target:
# res = mask.type_as(input) * input * (1./p)
# return [res, mask]
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
@register_decomposition(aten.detach)
def detach_decomposition(x: Tensor):
return x
@register_decomposition(aten._s_where)
def _s_where_canonicalization(a, b, c):
return aten.where(a, b, c)