Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] moved some stuff around
@@ -42,11 +42,6 @@ def tanh_backward_decomposition(out_grad: Tensor, y: Tensor):
 def sigmoid_backward_decomposition(out_grad: Tensor, y: Tensor):
     return out_grad * (y * (1 - y))
 
-# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
-@register_decomposition(aten.detach)
-def detach_decomposition(x: Tensor):
-    return x
-
 @register_decomposition(aten.softplus_backward)
 # The out argument seems to always be ignored?
 def softplus_backward_decomposition(out_grad: Tensor, x: Tensor, beta: float, threshold: float, out):
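As a quick sanity check of the sigmoid_backward decomposition shown in this hunk, the sketch below (plain PyTorch only, not functorch's registration machinery) compares out_grad * (y * (1 - y)) against the gradient autograd computes for torch.sigmoid:

import torch

# Standalone check: the sigmoid_backward decomposition written with plain torch ops.
def sigmoid_backward_decomposition(out_grad, y):
    return out_grad * (y * (1 - y))

x = torch.randn(8, requires_grad=True)
y = torch.sigmoid(x)
out_grad = torch.randn_like(y)
(expected,) = torch.autograd.grad(y, x, out_grad)   # what autograd produces
actual = sigmoid_backward_decomposition(out_grad, y.detach())
assert torch.allclose(actual, expected, atol=1e-6)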
@@ -75,9 +70,9 @@ def hardtanh_backward_decomposition(grad_output: Tensor, self: Tensor, min_val:
 def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
     return aten.where((self >= -lambd) & (self <= lambd), aten.new_zeros(grad_out, ()), grad_out)
 
-# @register_decomposition(aten.threshold_backward)
-# def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
-#     return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
+@register_decomposition(aten.threshold_backward)
+def threshold_backward_decomposition(grad_output: Tensor, self: Tensor, threshold: float):
+    return aten.where(self <= threshold, aten.new_zeros(grad_output, ()), grad_output)
 
 @register_decomposition(aten.leaky_relu_backward)
 def leaky_relu_backward(grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool):
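The threshold_backward decomposition in this hunk follows the same masking pattern as hardshrink_backward: keep the incoming gradient where the input cleared the threshold and zero it elsewhere. A minimal standalone check against autograd, using plain torch.where in place of aten.where and aten.new_zeros (an illustrative substitution, not the functorch code):

import torch

def threshold_backward_decomposition(grad_output, self, threshold):
    # Zero the gradient wherever the input did not exceed the threshold.
    return torch.where(self <= threshold, torch.zeros((), dtype=grad_output.dtype), grad_output)

x = torch.randn(16, requires_grad=True)
out = torch.nn.functional.threshold(x, 0.5, 0.0)
grad_out = torch.randn_like(out)
(expected,) = torch.autograd.grad(out, x, grad_out)
actual = threshold_backward_decomposition(grad_out, x.detach(), 0.5)
assert torch.allclose(actual, expected)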
@@ -100,6 +95,11 @@ def huber_loss_backward_decomposition(grad_output: Tensor, self: Tensor, target:
 # res = mask.type_as(input) * input * (1./p)
 # return [res, mask]
 
+# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
+@register_decomposition(aten.detach)
+def detach_decomposition(x: Tensor):
+    return x
+
 @register_decomposition(aten._s_where)
 def _s_where_canonicalization(a, b, c):
     return aten.where(a, b, c)
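For readers unfamiliar with the decorator pattern being moved around in this commit, the sketch below imitates it with a toy table keyed by name; the real functorch register_decomposition keys on aten overloads and is not reproduced here:

import torch
from torch import Tensor

DECOMPOSITION_TABLE = {}  # toy registry, stands in for functorch's table

def register_decomposition(op_name):
    def inner(fn):
        DECOMPOSITION_TABLE[op_name] = fn
        return fn
    return inner

@register_decomposition("detach")
def detach_decomposition(x: Tensor):
    # Only valid when the graph runs without autograd, e.g. an already-traced backward.
    return x

@register_decomposition("_s_where")
def _s_where_canonicalization(a, b, c):
    # Canonicalize the internal _s_where op to the public where.
    return torch.where(a, b, c)

cond = torch.tensor([True, False, True])
print(DECOMPOSITION_TABLE["_s_where"](cond, torch.ones(3), torch.zeros(3)))  # tensor([1., 0., 1.])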