# Defines derivative formulas and Python signatures of methods on Variable
#
# Each entry consists of:
# - A 'name', which specifies the ATen name of the function you
#   are defining derivatives for, and an argument specification.
# - One or more gradient entries, mapping a differentiable input
#   name to a formula specifying how to compute its gradient.
#   Note that a single gradient entry can specify the gradient
#   formula for multiple input names, by specifying a key
#   "input1, input2" (see atan2 for an example).
# - Optional entry with key 'output_differentiability' and value a list of the
#   same length as the number of outputs from the forward function. The list
#   should contain only booleans, specifying whether each of the output Tensors
#   is differentiable.
#   If none of the outputs is differentiable, you can also add the function
#   name to `gen_variable_type.py`'s `DONT_REQUIRE_DERIVATIVE` list.
#
# If a function has out-of-place and in-place variants, then the derivative
# definition for the in-place variant is optional. It will default to the
# definition for the out-of-place variant. Similarly, _out variants will
# default to the derivative for the non _out variant.
#
# Gradient expressions are standard C++ expressions operating on ATen
# variables. In a gradient expression, the following variables are in
# scope:
#
#   - 'grad', the gradient of the output (often spelled grad_output
#     in Python) which we are going to left-multiply.
#
#     When a function returns multiple *differentiable* outputs,
#     you can refer to the gradients of each output using 'grads',
#     e.g., 'grads[0]', 'grads[1]'.
#
#     When a function returns *one* differentiable output (the
#     first output) and some more nondifferentiable outputs,
#     you MUST refer to the gradient of the differentiable output with
#     'grad' (this case is special-cased in our code generation).
#
#     Note that the number of differentiable outputs can be modified by the
#     'output_differentiability' entry (see above).
#
#   - Any of the input arguments, tensor or non-tensor, including
#     argument names that only appear in Declarations.cwrap, e.g. 'output'.
#
#   - 'result', representing the result of evaluating the forward
#     expression for ATen native function declarations. If the forward
#     expression outputs a tuple, use 'resultX' instead to access the
#     X-th entry.
#
#   - 'grad_input_mask', a std::array<bool, n>, specifies which input
#     gradients are actually needed. For example, in the entry
#     `input0, input1: foo(grad_input_mask)`, `grad_input_mask` is a size
#     two array, where `grad_input_mask[0]` is true if `input0` requires
#     grad, and `grad_input_mask[1]` is true if `input1` requires grad.
#
#     (NB: if your function computes gradient for a list of tensors,
#     the `grad_input_mask` will only have a single entry for the list,
#     specifying whether at least one tensor from the list requires
#     grad. If we want to support more fine-grained signalling,
#     we'll need some alternate variable which is not a std::array)
#
#   - 'retain_variables', a bool which is true if a user has specified
#     that saved variables should be retained in case the backwards is
#     run again later. This allows an optimization where we can
#     destroy saved buffers if we know variables are not going to be retained,
#     e.g., it is used by _cudnn_rnn
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in tools/autograd/templates/Functions.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
#
# Double backwards gradient expressions can be somewhat confusing;
# the most important thing to remember is: (1) you need to define a
# derivative formula for every input, including inputs named things
# like 'grad_output', and (2) the gradient to multiply with is always
# called 'grad' (even though it really is a grad-grad).
#
# NB: There are a number of gradient definitions in here which are bogus
# (implemented using zeros_like). These gradients are (hopefully) not
# used by our frontend. You MUST check the frontend code; search for
# OpName.apply to see if it's still using a legacy Python style API.
#
# NB: The parameter names here MUST be consistent with the parameter names
# in ./torch/lib/ATen/Declarations.cwrap
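
# As a purely illustrative example tying the conventions above together
# (hypothetical -- `my_op` is not a real ATen function), an entry for an op
# returning (Tensor, Tensor) whose second output is non-differentiable, with a
# single gradient formula shared by both inputs, might look like:
#
# - name: my_op(Tensor self, Tensor other)
#   output_differentiability: [True, False]
#   self, other: my_op_backward(grad, self, other, grad_input_mask)
#
# where `my_op_backward` would be a helper written in
# tools/autograd/templates/Functions.cpp that uses `grad_input_mask` to skip
# gradients that are not actually needed.
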
- name: abs(Tensor self)
  self: grad * self.sign()

- name: acos(Tensor self)
  self: grad * -((-self * self + 1).rsqrt())

- name: add(Tensor self, Tensor other, *, Scalar alpha)
  self: grad
  other: maybe_multiply(grad, alpha)

- name: add(Tensor self, Scalar other, *, Scalar alpha)
  self: grad

- name: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  batch1: grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) }).bmm(batch2.transpose(1, 2)) * alpha
  batch2: batch1.transpose(1, 2).bmm(grad.unsqueeze(0).expand({ batch1.size(0), batch1.size(1), batch2.size(2) })) * alpha

- name: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
  self: grad
  tensor1: grad * value / tensor2
  tensor2: -grad * value * tensor1 / (tensor2 * tensor2)

- name: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value)
  self: grad
  tensor1: grad * tensor2 * value
  tensor2: grad * tensor1 * value

- name: _th_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
  mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)

- name: s_native_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
  mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)

- name: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  mat: grad.ger(vec) * alpha
  vec: mat.t().mv(grad) * alpha

- name: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  vec1: grad.mv(vec2) * alpha
  vec2: grad.t().mv(vec1) * alpha

- name: alias(Tensor self)
  self: grad

- name: as_strided(Tensor self, IntList size, IntList stride, int64_t storage_offset)
  self: as_strided_backward(grad, TensorGeometry(self), size, stride, storage_offset)

- name: asin(Tensor self)
  self: grad * (-self * self + 1).rsqrt()

- name: atan(Tensor self)
  self: grad / (self * self + 1)

- name: atan2(Tensor self, Tensor other)
  self, other: atan2_backward(grad, self, other, grad_input_mask)

- name: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha)
  self: maybe_multiply(grad, beta)
  batch1: grad.bmm(batch2.transpose(1, 2)) * alpha
  batch2: batch1.transpose(1, 2).bmm(grad) * alpha

- name: bernoulli(Tensor self, Generator generator)
  self: zeros_like(grad)

- name: bernoulli_(Tensor self, Tensor p, Generator generator)
  self: zeros_like(grad)
  p: zeros_like(p)

- name: bernoulli_(Tensor self, double p, Generator generator)
  self: zeros_like(grad)

- name: bmm(Tensor self, Tensor mat2)
  self: grad.bmm(mat2.transpose(1, 2))
  mat2: self.transpose(1, 2).bmm(grad)

- name: btrifact(Tensor self, bool pivot)
  self: not_implemented("btrifact")

- name: btrifact_with_info(Tensor self, bool pivot)
  self: not_implemented("btrifact_with_info")

- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
  self: not_implemented("btrisolve")

- name: cat(TensorList tensors, int64_t dim)
  tensors: cat_tensors_backward(grad, to_args_sizes(tensors), dim)

- name: cauchy_(Tensor self, double median, double sigma, Generator generator)
  self: zeros_like(grad)

- name: ceil(Tensor self)
  self: zeros_like(grad)

- name: cholesky(Tensor self, bool upper)
  self: cholesky_backward(grad, upper, result)

# For clamp, gradient is not defined at the boundaries. But empirically it's helpful
# to be able to get gradient on min and max, so we return the subgradient 1 for these cases.
- name: clamp(Tensor self, Scalar? min, Scalar? max)
  self: clamp_backward(grad, self, min, max)

- name: clamp_min(Tensor self, Scalar min)
  self: grad * (self >= min).type_as(grad)

- name: clamp_max(Tensor self, Scalar max)
  self: grad * (self <= max).type_as(grad)

- name: clone(Tensor self)
  self: grad

- name: coalesce(Tensor self)
  self: grad

- name: cos(Tensor self)
  self: grad * -self.sin()

- name: cosh(Tensor self)
  self: grad * self.sinh()

- name: cross(Tensor self, Tensor other, int64_t dim)
  self: other.cross(grad, dim)
  other: grad.cross(self, dim)

- name: cumprod(Tensor self, int64_t dim)
  self: cumprod_backward(grad, self, dim)

- name: cumprod(Tensor self, int64_t dim, *, ScalarType dtype)
  self: cumprod_backward(grad, self, dim, dtype)

- name: cumsum(Tensor self, int64_t dim)
  self: cumsum_backward(grad, dim)

- name: cumsum(Tensor self, int64_t dim, *, ScalarType dtype)
  self: cumsum_backward(grad, dim, self.type().scalarType())

- name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int64_t pad)
  self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad)

- name: _ctc_loss(Tensor log_probs, Tensor targets, IntList input_lengths, IntList target_lengths, int64_t blank)
  log_probs: _ctc_loss_backward(grad, log_probs, targets, input_lengths, target_lengths, result0, result1, blank)

- name: det(Tensor self)
  self: det_backward(grad, self, result)

- name: diag(Tensor self, int64_t diagonal)
  self: diag_backward(grad, self.sizes(), diagonal)

- name: diagonal(Tensor self, int64_t offset, int64_t dim1, int64_t dim2)
  self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)

- name: dist(Tensor self, Tensor other, Scalar p)
  self: norm_backward(grad, self - other, p, result)
  other: -norm_backward(grad, self - other, p, result)

- name: div(Tensor self, Tensor other)
  self: grad / other
  other: -grad * self / (other * other)

- name: div(Tensor self, Scalar other)
  self: grad / other

- name: dot(Tensor self, Tensor tensor)
  self: grad * tensor
  tensor: grad * self

- name: _fused_dropout(Tensor self, double p, Generator generator)
  self: _fused_dropout_backward(grad, result1, p)

- name: eig(Tensor self, bool eigenvectors)
  self: not_implemented("eig")

- name: eq_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: eq_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: erf(Tensor self)
  self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad

- name: erfc(Tensor self)
  self: -2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad

- name: erfinv(Tensor self)
  self: 0.5 * sqrt(M_PI) * exp(self.erfinv().pow(2)) * grad

- name: exp(Tensor self)
  self: grad * result

- name: expm1(Tensor self)
  self: grad * (result + 1)

- name: expand(Tensor self, IntList size, *, bool implicit)
  self: at::sum_to(grad, self.sizes())

- name: exponential_(Tensor self, double lambd, Generator generator)
  self: zeros_like(grad)

- name: fill_(Tensor self, Scalar value)
  self: zeros_like(grad)

- name: fill_(Tensor self, Tensor value)
  self: zeros_like(grad)
  value: grad.sum()

- name: floor(Tensor self)
  self: zeros_like(grad)

- name: fmod(Tensor self, Scalar other)
  self: grad

- name: fmod(Tensor self, Tensor other)
  self: grad
  other: 'not_implemented("fmod: other")'

- name: frac(Tensor self)
  self: grad

- name: gather(Tensor self, int64_t dim, Tensor index)
  self: at::zeros(self.sizes(), grad.options()).scatter_add_(dim, index, grad)

- name: ge_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: ge_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: gels(Tensor self, Tensor A)
  self: not_implemented("gels")
  A: not_implemented("gels")

- name: geometric_(Tensor self, double p, Generator generator)
  self: zeros_like(grad)

- name: geqrf(Tensor self)
  self: not_implemented("geqrf")

- name: ger(Tensor self, Tensor vec2)
  self: grad.mv(vec2)
  vec2: grad.t().mv(self)

- name: gesv(Tensor self, Tensor A)
  self: gesv_backward_self(grad, self, A)
  A: gesv_backward_A(grad, self, A, result0)

- name: indices(Tensor self)
  output_differentiability: [False]

- name: _indices(Tensor self)
  output_differentiability: [False]

- name: grid_sampler_2d(Tensor input, Tensor grid, int64_t interpolation_mode, int64_t padding_mode)
  input, grid: grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode)

- name: grid_sampler_3d(Tensor input, Tensor grid, int64_t interpolation_mode, int64_t padding_mode)
  input, grid: grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode)

- name: gt_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: gt_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: histc(Tensor self, int64_t bins, Scalar min, Scalar max)
  self: not_implemented("histc")

- name: index_add_(Tensor self, int64_t dim, Tensor index, Tensor source)
  self: grad
  source: grad.index_select(dim, index)

- name: index_copy_(Tensor self, int64_t dim, Tensor index, Tensor source)
  self: grad.clone().index_fill_(dim, index, 0)
  source: grad.index_select(dim, index)

- name: index_fill_(Tensor self, int64_t dim, Tensor index, Scalar value)
  self: grad.clone().index_fill_(dim, index, 0)

- name: index_fill_(Tensor self, int64_t dim, Tensor index, Tensor value)
  self: grad.clone().index_fill_(dim, index, 0)
  value: grad.index_select(dim, index).sum()

- name: index_select(Tensor self, int64_t dim, Tensor index)
  self: at::zeros(self.sizes(), grad.options()).index_add_(dim, index, grad)

- name: inverse(Tensor self)
  self: -at::matmul(result.transpose(-2, -1), at::matmul(grad, result.transpose(-2, -1)))

- name: kthvalue(Tensor self, int64_t k, int64_t dim, bool keepdim)
  self: index_select_backward(grad, dim, result1, self.sizes(), keepdim)

- name: le_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: le_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: lerp(Tensor self, Tensor end, Scalar weight)
  self: grad * (1 - weight.toDouble())
  end: grad * weight

- name: lgamma(Tensor self)
  self: grad * digamma(self)

- name: digamma(Tensor self)
  self: grad * polygamma(1, self)

- name: polygamma(int64_t n, Tensor self)
  self: grad * polygamma(n + 1, self)

- name: log(Tensor self)
  self: grad.div(self)

- name: log10(Tensor self)
  self: grad / (self * 2.3025850929940456)

- name: log1p(Tensor self)
  self: log1p_backward(grad, self)

- name: log2(Tensor self)
  self: grad / (self * 0.6931471805599453)

- name: logdet(Tensor self)
  self: logdet_backward(grad, self, result)

- name: log_normal_(Tensor self, double mean, double std, Generator generator)
  self: zeros_like(grad)

- name: logsumexp(Tensor self, int64_t dim, bool keepdim)
  self: logsumexp_backward(grad, self, result, dim, keepdim)

- name: lt_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: lt_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
  self: grad.clone().masked_fill_(mask, 0)

- name: masked_fill_(Tensor self, Tensor mask, Tensor value)
  self: grad.clone().masked_fill_(mask, 0)
  value: at::where(mask, grad, zeros_like(grad)).sum()

- name: masked_scatter_(Tensor self, Tensor mask, Tensor source)
  self: grad.clone().masked_fill_(mask, 0)
  source: masked_scatter_backward(grad, mask, source.sizes())

- name: masked_select(Tensor self, Tensor mask)
  # normally broadcasting is handled implicitly, but here, because we call an inplace
  # function as an optimization and the LHS doesn't broadcast for inplace functions,
  # we need to explicitly broadcast.
  self: zeros_like(self.expand(at::infer_size(self.sizes(), mask.sizes()))).masked_scatter_(mask, grad)

- name: max(Tensor self, int64_t dim, bool keepdim)
  self: index_select_backward(grad, dim, result1, self.sizes(), keepdim)

- name: max(Tensor self)
  self: select_equals_backward(grad, self, result)

- name: max(Tensor self, Tensor other)
  self: grad.clone().masked_fill_(self <= other, 0)
  other: grad.clone().masked_fill_(self > other, 0)

- name: mean(Tensor self, int64_t dim, bool keepdim)
  self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim)

- name: mean(Tensor self)
  self: grad.expand(self.sizes()) / self.numel()

- name: median(Tensor self)
  self: select_equals_backward(grad, self, result)

# This is in theory incorrect in the following case:
#   sorted list: [..., a, b, b, ..., b, b, c, ...] with median = b and the value
#                                 |  at middle position of the
#                                 |  list between two `b`s. E.g.,
#                                 |
#                                 ^the middle position
#   The gradient exists and is essentially 0 in this case.
#
#   In case where the middle position is at the boundary of the `b` range, e.g.,
#   sorted list: [..., a, b, b, ..., b, b, c, ...]
#                                        |
#                                        ^the middle position
#   The backward implementation is correct in the sense that it returns the
#   subgradient on one side.
- name: median(Tensor self, int64_t dim, bool keepdim)
  self: index_select_backward(grad, dim, result1, self.sizes(), keepdim)

- name: min(Tensor self, int64_t dim, bool keepdim)
  self: index_select_backward(grad, dim, result1, self.sizes(), keepdim)

- name: min(Tensor self)
  self: select_equals_backward(grad, self, result)

- name: min(Tensor self, Tensor other)
  self: grad.clone().masked_fill_(self >= other, 0)
  other: grad.clone().masked_fill_(self < other, 0)

- name: mm(Tensor self, Tensor mat2)
  self: mm_mat1_backward(grad, mat2, self, 1)
  mat2: mm_mat2_backward(grad, self, mat2.sizes(), mat2.strides(), 1)

- name: mode(Tensor self, int64_t dim, bool keepdim)
  self: index_select_backward(grad, dim, result1, self.sizes(), keepdim)

- name: mul(Tensor self, Tensor other)
  self: grad * other
  other: grad * self

- name: mul(Tensor self, Scalar other)
  self: grad * other

- name: mv(Tensor self, Tensor vec)
  self: grad.ger(vec)
  vec: self.t().mv(grad)

- name: mvlgamma(Tensor self, int64_t p)
  self: mvlgamma_backward(grad, self, p)

- name: native_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps)
  input, weight, bias: native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, eps, grad_input_mask)

- name: native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_invstd, bool train, double eps, std::array<bool,3> output_mask)
  input, weight, grad_out: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_out, running_mean, running_var, train, eps, save_mean, save_invstd, grad_input_mask)
  save_mean: not_implemented("native_batch_norm_backward save_mean")
  save_invstd: not_implemented("native_batch_norm_backward save_invstd")

- name: ne_(Tensor self, Scalar other)
  self: zeros_like(self)

- name: ne_(Tensor self, Tensor other)
  self: zeros_like(self)
  other: zeros_like(other)

- name: neg(Tensor self)
  self: grad.neg()

- name: norm(Tensor self, Scalar p)
  self: norm_backward(grad, self, p, result)

- name: norm(Tensor self, Scalar p, int64_t dim, bool keepdim)
  self: norm_backward(grad, self, p, result, dim, keepdim)

- name: _pdist_forward(Tensor self, double p)
  self: _pdist_backward(grad, self, p, result)

- name: _pdist_backward(Tensor grad, Tensor self, double p, Tensor pdist)
  grad: not_implemented("_pdist_backward")
  self: not_implemented("_pdist_backward")
  pdist: not_implemented("_pdist_backward")

- name: normal_(Tensor self, double mean, double std, Generator generator)
  self: zeros_like(grad)

- name: normal(Tensor mean, double std, Generator generator)
  mean: at::zeros(mean.sizes(), grad.options())

- name: normal(double mean, Tensor std, Generator generator)
  std: at::zeros(std.sizes(), grad.options())

- name: normal(Tensor mean, Tensor std, Generator generator)
  mean: at::zeros(mean.sizes(), grad.options())
  std: at::zeros(std.sizes(), grad.options())

- name: orgqr(Tensor self, Tensor input2)
  self: not_implemented("orgqr")
  input2: not_implemented("orgqr")

- name: ormqr(Tensor self, Tensor input2, Tensor input3, bool left, bool transpose)
  self: not_implemented("ormqr")
  input2: not_implemented("ormqr")
  input3: not_implemented("ormqr")

- name: permute(Tensor self, IntList dims)
  self: permute_backwards(grad, dims)

- name: poisson(Tensor self, Generator generator)
  self: zeros_like(self)

- name: potri(Tensor self, bool upper)
  self: not_implemented("potri")

- name: potrs(Tensor self, Tensor input2, bool upper)
  self: not_implemented("potrs")
  input2: not_implemented("potrs")

- name: pow(Tensor self, Scalar exponent)
  self: pow_backward(grad, self, exponent)

- name: pow(Tensor self, Tensor exponent)
  self: pow_backward_self(grad, self, exponent)
  exponent: pow_backward_exponent(grad, self, exponent)

- name: pow(Scalar self, Tensor exponent)
  exponent: pow_backward_exponent(grad, self, exponent)

- name: prod(Tensor self)
  self: prod_backward(grad, self, result)

- name: prod(Tensor self, ScalarType dtype)
  self: prod_backward(grad, self.to(grad.type().scalarType()), result).to(self.type().scalarType())

- name: prod(Tensor self, int64_t dim, bool keepdim)
  self: prod_backward(grad, self, result, dim, keepdim)

- name: prod(Tensor self, int64_t dim, ScalarType dtype)
  self: prod_backward(grad, self.to(grad.type().scalarType()), result, dim, false).to(self.type().scalarType())

- name: prod(Tensor self, int64_t dim, bool keepdim, ScalarType dtype)
  self: prod_backward(grad, self.to(grad.type().scalarType()), result, dim, keepdim).to(self.type().scalarType())

- name: pstrf(Tensor self, bool upper, Scalar tol)
  self: not_implemented("pstrf")

- name: put_(Tensor self, Tensor index, Tensor source, bool accumulate)
  self: grad.clone().put_(index, zeros_like(source), accumulate)
  source: grad.take(index)

- name: qr(Tensor self)
  self: not_implemented("qr")

- name: random_(Tensor self, int64_t from, int64_t to, Generator generator)
  self: zeros_like(grad)

- name: random_(Tensor self, int64_t to, Generator generator)
  self: zeros_like(grad)

- name: random_(Tensor self, Generator generator)
  self: zeros_like(grad)

- name: reciprocal(Tensor self)
  self: -grad * result * result

- name: remainder(Tensor self, Scalar other)
  self: grad

- name: remainder(Tensor self, Tensor other)
  self: grad

- name: renorm(Tensor self, Scalar p, int64_t dim, Scalar maxnorm)
  self: renorm_backward(grad, self, p, dim, maxnorm)

- name: repeat(Tensor self, IntList repeats)
  self: repeat_backward(grad, self.dim(), repeats)

# DO NOT define a backward for reshape!
# reshape is special in that it sometimes returns a view, and sometimes not.
# Defining a backward will make codegen spit out the forward call as
# as_variable(baseType->reshape(self)),
# making it impossible (hard) to detect when it is actually a view.
# - name: reshape(Tensor self, IntList shape)

- name: RoiPooling2d_forward(Tensor input, Tensor rois, int64_t pooledHeight, int64_t pooledWidth, double spatialScale)
  input: RoiPooling2d_backward(input, rois, pooledHeight, pooledWidth, spatialScale, grad, result1)

- name: round(Tensor self)
  self: zeros_like(grad)

- name: rsqrt(Tensor self)
  self: -0.5 * grad * result.pow(3)

- name: scatter_(Tensor self, int64_t dim, Tensor index, Tensor src)
  self: grad.clone().scatter_(dim, index, 0)
  src: grad.gather(dim, index)

- name: scatter_(Tensor self, int64_t dim, Tensor index, Scalar value)
  self: grad.clone().scatter_(dim, index, 0)

- name: scatter_add_(Tensor self, int64_t dim, Tensor index, Tensor src)
  self: grad
  src: grad.gather(dim, index)

- name: select(Tensor self, int64_t dim, int64_t index)
  self: select_backward(grad, self.sizes(), dim, index)

- name: sigmoid(Tensor self)
  self: _sigmoid_backward(grad, result)

- name: sign(Tensor self)
  self: zeros_like(grad)

- name: sin(Tensor self)
  self: grad * self.cos()

- name: sinh(Tensor self)
  self: grad * self.cosh()

- name: slice(Tensor self, int64_t dim, int64_t start, int64_t end, int64_t step)
  self: slice_backward(grad, self.sizes(), dim, start, end, step)

- name: slogdet(Tensor self)
  self: slogdet_backward(grads, self, result0, result1)

- name: sort(Tensor self, int64_t dim, bool descending)
  self: index_select_backward(grad, dim, result1, self.sizes(), true)

- name: split(Tensor self, int64_t split_size, int64_t dim)
  self: split_backward(grads, split_size, dim, self.sizes(), self.type())

- name: split_with_sizes(Tensor self, IntList split_sizes, int64_t dim)
  self: split_with_sizes_backward(grads, split_sizes, dim, self.sizes(), self.type())

- name: sqrt(Tensor self)
  self: grad / (2 * result)

- name: squeeze(Tensor self)
  self: unsqueeze_to(grad, self.sizes());

- name: squeeze(Tensor self, int64_t dim)
  self: unsqueeze_to(grad, dim, self.sizes())

- name: squeeze_(Tensor self)
  self: unsqueeze_to(grad, self.sizes());

- name: squeeze_(Tensor self, int64_t dim)
  self: unsqueeze_to(grad, dim, self.sizes())

- name: std(Tensor self, bool unbiased)
  self: var_backward(grad / (result * 2), self, unbiased)

- name: std(Tensor self, int64_t dim, bool unbiased, bool keepdim)
  self: var_backward(grad / (result * 2), self, dim, unbiased, keepdim)

- name: sub(Tensor self, Tensor other, *, Scalar alpha)
  self: grad
  other: -grad * alpha

- name: sub(Tensor self, Scalar other, *, Scalar alpha)
  self: grad

- name: rsub(Tensor self, Tensor other, *, Scalar alpha)
  self: -grad * alpha
  other: grad

- name: rsub(Tensor self, Scalar other, *, Scalar alpha)
  self: -grad * alpha

- name: sum(Tensor self)
  self: grad.expand(self.sizes())

- name: sum(Tensor self, ScalarType dtype)
  self: grad.expand(self.sizes()).to(self.type().scalarType())

- name: sum(Tensor self, IntList dim, bool keepdim)
  self: sum_backward(grad, self.sizes(), dim, keepdim)

- name: sum(Tensor self, IntList dim, ScalarType dtype)
  self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType())

- name: sum(Tensor self, IntList dim, bool keepdim, ScalarType dtype)
  self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType())

- name: svd(Tensor self, bool some, bool compute_uv)
  self: svd_backward(grads, self, some, compute_uv, result0, result1, result2)

- name: symeig(Tensor self, bool eigenvectors, bool upper)
  self: symeig_backward(grads, self, eigenvectors, upper, result0, result1)

- name: t(Tensor self)
  self: grad.t()

- name: flip(Tensor self, IntList dims)
  self: grad.flip(dims)

- name: roll(Tensor self, IntList shifts, IntList dims)
  self: grad.roll(-shifts[0], dims)

- name: rot90(Tensor self, int64_t k, IntList dims)
  self: grad.rot90(-k, dims)

- name: take(Tensor self, Tensor index)
  self: zeros_like(self).put_(index, grad, true)

- name: tan(Tensor self)
  self: grad * (1 + result.pow(2))

- name: tanh(Tensor self)
  self: _tanh_backward(grad, result)

- name: topk(Tensor self, int64_t k, int64_t dim, bool largest, bool sorted)
  self: index_select_backward(grad, dim, result1, self.sizes(), true)

- name: trace(Tensor self)
  self: trace_backward(grad, self.sizes())

- name: transpose(Tensor self, int64_t dim0, int64_t dim1)
  self: grad.transpose(dim0, dim1)

- name: transpose_(Tensor self, int64_t dim0, int64_t dim1)
  self: grad.transpose(dim0, dim1)

- name: tril(Tensor self, int64_t diagonal)
  self: grad.tril(diagonal)

- name: triu(Tensor self, int64_t diagonal)
  self: grad.triu(diagonal)

- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
  self, A: trtrs_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask)

- name: trunc(Tensor self)
  self: zeros_like(grad)

- name: unfold(Tensor self, int64_t dimension, int64_t size, int64_t step)
  self: unfold_backward(grad, self.sizes(), dimension, size, step)

- name: uniform_(Tensor self, double from, double to, Generator generator)
  self: zeros_like(grad)

- name: _unique(Tensor self, bool sorted, bool return_inverse)
  self: not_implemented("_unique")

- name: _unsafe_view(Tensor self, IntList size)
  self: grad.reshape(self.sizes())

- name: unsqueeze(Tensor self, int64_t dim)
  self: grad.squeeze(dim)

- name: unsqueeze_(Tensor self, int64_t dim)
  self: grad.squeeze(dim)

- name: var(Tensor self, bool unbiased)
  self: var_backward(grad, self, unbiased)

- name: var(Tensor self, int64_t dim, bool unbiased, bool keepdim)
  self: var_backward(grad, self, dim, unbiased, keepdim)

- name: view(Tensor self, IntList size)
  self: grad.reshape(self.sizes())

- name: _s_where(Tensor condition, Tensor self, Tensor other)
  self: where(condition, grad, zeros_like(grad))
  other: where(condition, zeros_like(grad), grad)

# weight_norm_cuda_interface_backward does not have an explicitly defined derivative, so if we do happen
# to be running backward with create_graph=True, fall back to a backward function that uses
# differentiable ops.
- name: _weight_norm_cuda_interface(Tensor v, Tensor g, int64_t dim)
  v, g: "GradMode::is_enabled() ? _weight_norm_differentiable_backward(grad.contiguous(), v, g, result1, dim) : _weight_norm_cuda_interface_backward(grad.contiguous(), v, g, result1, dim)"

- name: zero_(Tensor self)
  self: zeros_like(grad)

- name: sparse_mask(Tensor self, SparseTensorRef mask)
  self: not_implemented("sparse_mask")
  mask: not_implemented("sparse_mask")

- name: _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, IntList size, Tensor indices, Tensor values, TensorOptions options)
  values: sparse_constructor_values_backward(grad, indices, values.sizes())

- name: _standard_gamma(Tensor self, Generator generator)
  self: grad * _standard_gamma_grad(self, result)

- name: _standard_gamma_grad(Tensor self, Tensor output)
  self: not_implemented("_standard_gamma_grad")

- name: values(Tensor self)
  self: at::_sparse_coo_tensor_unsafe(self.indices(), grad, self.sizes())._coalesced_(true);

# Why is _values() not differentiable?
# See NOTE [ Sparse: autograd and API ]
- name: _values(Tensor self)
  output_differentiability: [False]

# NN
- name: _trilinear(Tensor i1, Tensor i2, Tensor i3, IntList expand1, IntList expand2, IntList expand3, IntList sumdim, int64_t unroll_dim)
  i1, i2, i3: _trilinear_backward(grad, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, grad_input_mask)

- name: constant_pad_nd(Tensor self, IntList pad, Scalar value)
  self: constant_pad_nd_backward(grad, pad)

- name: binary_cross_entropy_forward(Tensor self, Tensor target, Tensor weight, int64_t reduction)
  self: binary_cross_entropy_backward(grad, self, target, weight, reduction)

- name: binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor weight, Tensor pos_weight, int64_t reduction)
  self: binary_cross_entropy_with_logits_backward(grad, self, target, weight, pos_weight, reduction)
  target: binary_cross_entropy_with_logits_target_backward(grad, self, target, weight, pos_weight, reduction)

- name: embedding(Tensor weight, Tensor indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse)
  weight: embedding_backward(grad, indices, weight.size(0), padding_idx, scale_grad_by_freq, sparse)

- name: _embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int64_t mode, bool sparse)
  weight: _embedding_bag_backward(grad, indices, offsets, result1, result2, result3, weight.size(0), scale_grad_by_freq, mode, sparse)

- name: embedding_renorm_(Tensor self, Tensor indices, double max_norm, double norm_type)
  self: not_implemented("embedding_renorm")

- name: kl_div(Tensor self, Tensor target, int64_t reduction)
  self: kl_div_backward(grad, self, target, reduction)
  target: kl_div_target_backward(grad, self, target, reduction)

- name: l1_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: l1_loss_backward(grad, self, target, reduction)

- name: mse_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: mse_loss_backward(grad, self, target, reduction)

- name: multi_margin_loss_forward(Tensor self, Tensor target, Scalar p, Scalar margin, Tensor weight, int64_t reduction)
  self: multi_margin_loss_backward(grad, self, target, p, margin, weight, reduction)

- name: multilabel_margin_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: multilabel_margin_loss_backward(grad, self, target, reduction, is_target)

- name: nll_loss_forward(Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index)
  self: nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight)

- name: nll_loss2d_forward(Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index)
  self: nll_loss2d_backward(grad, self, target, weight, reduction, ignore_index, total_weight)

- name: smooth_l1_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: smooth_l1_loss_backward(grad, self, target, reduction)

- name: soft_margin_loss_forward(Tensor self, Tensor target, int64_t reduction)
  self: soft_margin_loss_backward(grad, self, target, reduction)

- name: relu(Tensor self)
  self: threshold_backward(grad, self, 0)

# NB: `output` instead of `self` saves memory. It avoids saving a copy of self.
- name: relu_(Tensor self)
  self: threshold_backward(grad, output, 0)

- name: elu_forward(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale)
  self: elu_backward(grad, alpha, scale, input_scale, output)

- name: glu_forward(Tensor self, int64_t dim)
  self: glu_backward(grad, self, dim)

- name: hardshrink(Tensor self, Scalar lambd)
  self: hardshrink_backward(grad, self, lambd)

- name: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd)
  grad_out: hardshrink_backward(grad, self, lambd)
  self: zeros_like(grad)

- name: hardtanh_forward(Tensor self, Scalar min_val, Scalar max_val)
  self: hardtanh_backward(grad, self, min_val, max_val)

- name: hardtanh_forward_(Tensor self, Scalar min_val, Scalar max_val)
  self: hardtanh_backward(grad, output, min_val, max_val)

- name: leaky_relu_forward(Tensor self, Scalar negative_slope)
  self: leaky_relu_backward(grad, self, negative_slope)

- name: leaky_relu_forward_(Tensor self, Scalar negative_slope)
  self: leaky_relu_backward(grad, output, negative_slope)

- name: log_sigmoid_forward(Tensor self)
  self: log_sigmoid_backward(grad, self, buffer)

- name: _log_softmax(Tensor self, int64_t dim, bool half_to_float)
  self: _log_softmax_backward_data(grad, result, dim, self)

- name: prelu(Tensor self, Tensor weight)
  self, weight: prelu_backward(grad, self, weight)

- name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight)
  grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight)

- name: rrelu_with_noise_forward(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
  self: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)

- name: rrelu_with_noise_forward_(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
  self: rrelu_with_noise_backward(grad, output, noise, lower, upper, training)

- name: _softmax(Tensor self, int64_t dim, bool half_to_float)
  self: _softmax_backward_data(grad, result, dim, self)

- name: softplus_forward(Tensor self, Scalar beta, Scalar threshold)
  self: softplus_backward(grad, self, beta, threshold, output)

- name: softshrink_forward(Tensor self, Scalar lambd)
  self: softshrink_backward(grad, self, lambd)

- name: threshold(Tensor self, Scalar threshold, Scalar value)
  self: threshold_backward(grad, self, threshold)

- name: threshold_(Tensor self, Scalar threshold, Scalar value)
  self: threshold_backward(grad, output, threshold)

- name: reflection_pad1d_forward(Tensor self, IntList padding)
  self: reflection_pad1d_backward(grad, self, padding)

- name: reflection_pad2d_forward(Tensor self, IntList padding)
  self: reflection_pad2d_backward(grad, self, padding)

- name: replication_pad1d_forward(Tensor self, IntList padding)
  self: replication_pad1d_backward(grad, self, padding)

- name: replication_pad2d_forward(Tensor self, IntList padding)
  self: replication_pad2d_backward(grad, self, padding)

- name: replication_pad3d_forward(Tensor self, IntList padding)
  self: replication_pad3d_backward(grad, self, padding)

- name: upsample_linear1d_forward(Tensor self, IntList output_size, bool align_corners)
  self: upsample_linear1d_backward(grad, output_size, self.sizes(), align_corners)

- name: upsample_bilinear2d_forward(Tensor self, IntList output_size, bool align_corners)
  self: upsample_bilinear2d_backward(grad, output_size, self.sizes(), align_corners)

- name: upsample_trilinear3d_forward(Tensor self, IntList output_size, bool align_corners)
  self: upsample_trilinear3d_backward(grad, output_size, self.sizes(), align_corners)

- name: upsample_nearest1d_forward(Tensor self, IntList output_size)
  self: upsample_nearest1d_backward(grad, output_size, self.sizes())

- name: upsample_nearest2d_forward(Tensor self, IntList output_size)
  self: upsample_nearest2d_backward(grad, output_size, self.sizes())

- name: upsample_nearest3d_forward(Tensor self, IntList output_size)
  self: upsample_nearest3d_backward(grad, output_size, self.sizes())

- name: adaptive_avg_pool2d_forward(Tensor self, IntList output_size)
  self: adaptive_avg_pool2d_backward(grad, self)

- name: adaptive_avg_pool3d_forward(Tensor self, IntList output_size)
  self: adaptive_avg_pool3d_backward(grad, self)

- name: adaptive_max_pool2d_forward(Tensor self, IntList output_size)
  self: adaptive_max_pool2d_backward(grad, self, indices)

- name: adaptive_max_pool3d_forward(Tensor self, IntList output_size)
  self: adaptive_max_pool3d_backward(grad, self, indices)

- name: avg_pool2d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
  self: avg_pool2d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad)

- name: avg_pool3d_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
  self: avg_pool3d_backward(grad, self, kernel_size, stride, padding, ceil_mode, count_include_pad)

- name: fractional_max_pool2d_forward(Tensor self, IntList kernel_size, IntList output_size, Tensor random_samples)
  self: fractional_max_pool2d_backward(grad, self, kernel_size, output_size, indices)

- name: max_pool2d_with_indices_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode)
  self: max_pool2d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices)

- name: max_pool3d_with_indices_forward(Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode)
  self: max_pool3d_with_indices_backward(grad, self, kernel_size, stride, padding, dilation, ceil_mode, indices)

- name: max_unpool2d_forward(Tensor self, Tensor indices, IntList output_size)
  self: max_unpool2d_backward(grad, self, indices, output_size)

- name: max_unpool3d_forward(Tensor self, Tensor indices, IntList output_size, IntList stride, IntList padding)
  self: max_unpool3d_backward(grad, self, indices, output_size, stride, padding)

- name: thnn_conv_transpose2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
  self, weight, bias: thnn_conv_transpose2d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, columns, ones, grad_input_mask)

- name: thnn_conv_transpose2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)

- name: thnn_conv_transpose3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList output_padding, IntList dilation)
  self, weight, bias: thnn_conv_transpose3d_backward(grad, self, weight, kernel_size, stride, padding, output_padding, dilation, finput, fgrad_input, grad_input_mask)

- name: thnn_conv_transpose3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList output_padding, IntList dilation, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, 1, false, false, false, grad_input_mask)

- name: thnn_conv2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding)
  self, weight, bias: thnn_conv2d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask)

- name: thnn_conv2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1}}, false, {{0, 0}}, 1, false, false, false, grad_input_mask)

- name: thnn_conv_depthwise2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
  self, weight: thnn_conv_depthwise2d_backward(grad.contiguous(), self, weight, kernel_size, stride, padding, dilation, grad_input_mask)
  bias: grad.contiguous().view({grad.size(0), grad.size(1), -1}).sum(0).sum(1)

- name: thnn_conv_depthwise2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, std::array<bool,2> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], {}, grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, self.size(1), false, false, false, grad_input_mask)

- name: thnn_conv3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding)
  self, weight, bias: thnn_conv3d_backward(grad, self, weight, kernel_size, stride, padding, finput, fgrad_input, grad_input_mask)

- name: thnn_conv3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, Tensor finput, Tensor fgrad_input, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, {{1, 1, 1}}, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)

- name: thnn_conv_dilated2d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
  self, weight, bias: thnn_conv_dilated2d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask)

- name: thnn_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0}}, 1, false, false, false, grad_input_mask)

- name: thnn_conv_dilated3d_forward(Tensor self, Tensor weight, IntList kernel_size, Tensor bias, IntList stride, IntList padding, IntList dilation)
  self, weight, bias: thnn_conv_dilated3d_backward(grad, self, weight, kernel_size, stride, padding, dilation, columns, ones, grad_input_mask)

- name: thnn_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, IntList kernel_size, IntList stride, IntList padding, IntList dilation, Tensor columns, Tensor ones, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, {{0, 0, 0}}, 1, false, false, false, grad_input_mask)

# NN double backwards support
|
|
|
|
- name: adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self)
|
|
grad_output: adaptive_avg_pool2d(grad, { grad_output.size(-2), grad_output.size(-1) })
|
|
self: zeros_like(self)
|
|
|
|
- name: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self)
|
|
grad_output: adaptive_avg_pool3d(grad, { grad_output.size(-3), grad_output.size(-2), grad_output.size(-1) })
|
|
self: zeros_like(self)
|
|
|
|
- name: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices)
|
|
grad_output: max_pool_double_backward(grad, indices, 2)
|
|
self: zeros_like(self)
|
|
|
|
- name: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices)
|
|
grad_output: max_pool_double_backward(grad, indices, 3)
|
|
self: zeros_like(self)
|
|
|
|
- name: avg_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
|
|
grad_output: avg_pool2d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
|
self: zeros_like(self)
|
|
|
|
- name: avg_pool3d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, bool ceil_mode, bool count_include_pad)
|
|
grad_output: avg_pool3d(grad, kernel_size, stride, padding, ceil_mode, count_include_pad)
|
|
self: zeros_like(self)
|
|
|
|
- name: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output)
|
|
grad_output: elu_backward(grad, alpha, scale, input_scale, output)
|
|
output: grad * grad_output * input_scale * (output < 0).toType(grad.type())
|
|
|
|
- name: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList output_size, Tensor indices)
|
|
grad_output: max_pool_double_backward(grad, indices, 2)
|
|
self: zeros_like(self)
|
|
|
|
- name: glu_backward(Tensor grad_output, Tensor self, int64_t dim)
|
|
grad_output: glu_double_backward_grad_output(grad, self, dim)
|
|
self: glu_double_backward(grad, grad_output, self, dim)
|
|
|
|
- name: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val)
|
|
grad_output: hardtanh_backward(grad, self, min_val, max_val)
|
|
self: zeros_like(grad)
|
|
|
|
- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction)
|
|
grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction)
|
|
self: zeros_like(grad)
|
|
target: zeros_like(grad)
|
|
|
|
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction)
|
|
grad_output: l1_loss_double_backward_grad_output(grad, self, target, reduction)
|
|
self: zeros_like(grad)
|
|
|
|
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer)
|
|
grad_output: log_sigmoid_backward(grad, self, buffer)
|
|
self: log_sigmoid_double_backward(grad * grad_output, self)
|
|
|
|
- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int64_t dim, Tensor self)
|
|
grad_output: grad - (grad * output.exp()).sum(dim, true)
|
|
self: log_softmax_double_backward(grad, grad_output, dim, output).type_as(self)
|
|
|
|
- name: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope)
|
|
grad_output: leaky_relu_backward(grad, self, negative_slope)
|
|
self: zeros_like(grad)
|
|
|
|
- name: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices)
|
|
grad_output: max_pool_double_backward(grad, indices, 2);
|
|
self: zeros_like(self)
|
|
|
|
- name: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, IntList kernel_size, IntList stride, IntList padding, IntList dilation, bool ceil_mode, Tensor indices)
|
|
grad_output: max_pool_double_backward(grad, indices, 3);
|
|
self: zeros_like(self)
|
|
|
|
- name: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, IntList output_size)
|
|
grad_output: max_unpool2d(grad, indices, output_size)
|
|
self: zeros_like(self)
|
|
|
|
- name: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction)
|
|
grad_output: mse_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
|
|
self: mse_loss_double_backward(grad * grad_output, self, reduction)
|
|
|
|
- name: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index, Tensor total_weight)
|
|
grad_output: nll_loss(grad, target, weight, reduction, ignore_index)
|
|
self: zeros_like(grad)
|
|
|
|
- name: nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor weight, int64_t reduction, int64_t ignore_index, Tensor total_weight)
|
|
grad_output: nll_loss2d(grad, target, weight, reduction, ignore_index)
|
|
self: zeros_like(grad)
|
|
|
|
- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training)
|
|
grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
|
|
self: zeros_like(grad)
|
|
|
|
- name: reflection_pad1d_backward(Tensor grad_output, Tensor self, IntList padding)
|
|
grad_output: reflection_pad1d(grad, padding)
|
|
self: zeros_like(self)
|
|
|
|
- name: reflection_pad2d_backward(Tensor grad_output, Tensor self, IntList padding)
|
|
grad_output: reflection_pad2d(grad, padding)
|
|
self: zeros_like(self)
|
|
|
|
- name: replication_pad1d_backward(Tensor grad_output, Tensor self, IntList padding)
|
|
grad_output: replication_pad1d(grad, padding)
|
|
self: zeros_like(self)
|
|
|
|
- name: replication_pad2d_backward(Tensor grad_output, Tensor self, IntList padding)
|
|
grad_output: replication_pad2d(grad, padding)
|
|
self: zeros_like(self)
|
|
|
|
- name: replication_pad3d_backward(Tensor grad_output, Tensor self, IntList padding)
|
|
grad_output: replication_pad3d(grad, padding)
|
|
self: zeros_like(self)
|
|
|
|
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction)
|
|
grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
|
|
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction)
|
|
|
|
- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output)
|
|
grad_output: softplus_backward(grad, self, beta, threshold, output)
|
|
self: softplus_double_backward(grad * grad_output, self, beta, threshold)
|
|
|
|
- name: _softmax_backward_data(Tensor grad_output, Tensor output, int64_t dim, Tensor self)
|
|
grad_output: _softmax_backward_data(grad, output, dim, self)
|
|
self: softmax_double_backward(grad, grad_output, dim, output).type_as(self)
|
|
|
|
- name: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction)
|
|
grad_output: soft_margin_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
|
|
self: soft_margin_loss_double_backward(grad * grad_output, self, target, reduction)
|
|
|
|
- name: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd)
|
|
grad_output: softshrink_backward(grad, self, lambd)
|
|
self: zeros_like(grad)
|
|
|
|
- name: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold)
|
|
grad_output: threshold_backward(grad, self, threshold)
|
|
self: zeros_like(grad)
|
|
|
|

- name: upsample_linear1d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
  grad_output: upsample_linear1d(grad, output_size, align_corners)

- name: upsample_bilinear2d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
  grad_output: upsample_bilinear2d(grad, output_size, align_corners)

- name: upsample_trilinear3d_backward(Tensor grad_output, IntList output_size, IntList input_size, bool align_corners)
  grad_output: upsample_trilinear3d(grad, output_size, align_corners)

- name: upsample_nearest1d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_nearest1d(grad, output_size)

- name: upsample_nearest2d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_nearest2d(grad, output_size)

- name: upsample_nearest3d_backward(Tensor grad_output, IntList output_size, IntList input_size)
  grad_output: upsample_nearest3d(grad, output_size)

- name: _sigmoid_backward(Tensor grad_output, Tensor output)
  grad_output: _sigmoid_backward(grad, output)
  output: grad * grad_output * (-2 * output + 1)

- name: _tanh_backward(Tensor grad_output, Tensor output)
  grad_output: _tanh_backward(grad, output)
  output: -2 * output * grad * grad_output

# cudnn
- name: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, IntList input_lengths, IntList target_lengths, int64_t blank, bool deterministic)
  log_probs: result1

- name: cudnn_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: cudnn_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)

- name: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)

- name: cudnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: cudnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)

- name: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)

# The above backward definitions are equivalent to the definitions below. Why do we bundle
# everything up? It's because it's more convenient to define double backwards
# when there is a single function that manages everything.
#
# Unfortunately, there's one downside to not doing it all in one go: we
# unconditionally save input and weight, even if weight/input gradients are not
# being computed. That's too bad.
#
# input: cudnn_convolution_backward_input(input.sizes(), grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic)
# weight: cudnn_convolution_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic)
# bias: cudnn_convolution_backward_bias(grad.contiguous())
#
# input: cudnn_convolution_transpose_backward_input(grad.contiguous(), weight, padding, stride, dilation, groups, benchmark, deterministic)
# weight: cudnn_convolution_transpose_backward_weight(weight.sizes(), grad.contiguous(), input, padding, stride, dilation, groups, benchmark, deterministic)
# bias: cudnn_convolution_backward_bias(grad.contiguous())
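#
# Illustrative sketch only (not consumed by this file's codegen): "double
# backwards" above is what a user hits when differentiating a gradient with
# PyTorch's Python API (torch.autograd.grad and F.conv2d assumed, on a
# CUDA/cuDNN build):
#
#   import torch
#   import torch.nn.functional as F
#
#   x = torch.randn(1, 3, 8, 8, device='cuda', requires_grad=True)
#   w = torch.randn(4, 3, 3, 3, device='cuda', requires_grad=True)
#   y = F.conv2d(x, w, padding=1)
#   # create_graph=True records the backward pass itself, so the resulting
#   # gradient can be differentiated again.
#   gx, = torch.autograd.grad(y.sum(), x, create_graph=True)
#   # Differentiating gx w.r.t. w goes through the single bundled backward
#   # entry (via _convolution_double_backward) rather than separate per-input
#   # backward functions.
#   gw, = torch.autograd.grad(gx.sum(), w)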

- name: cudnn_grid_sampler(Tensor self, Tensor grid)
  self, grid: cudnn_grid_sampler_backward(self, grid, grad)

- name: cudnn_affine_grid_generator(Tensor theta, int64_t N, int64_t C, int64_t H, int64_t W)
  theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W)

# NB: Why is the backwards here so complicated? CuDNN cannot be used to compute
# backward in evaluation mode, because the math for backward in evaluation mode
# is different (since the forward math is different), and CuDNN does not support
# it. And in any case, you shouldn't be using this batch norm in evaluation mode,
# because it should be merged into the previous convolution (left for future
# work).
# NB2: The quotes around the gradient are needed to appease YAML parsing rules.
- name: cudnn_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
  input, weight, bias: "training ? cudnn_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"
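
# Illustrative note (not consumed by the codegen), to make "the math is
# different" concrete: in evaluation mode the running statistics are constants,
# so with an affine weight the forward is the fixed affine map
#   y = weight * (input - running_mean) / sqrt(running_var + epsilon) + bias
# and the input gradient reduces to
#   dL/dinput = dL/dy * weight / sqrt(running_var + epsilon),
# whereas in training mode the batch mean/variance themselves depend on input
# and contribute extra terms; that is why the entry above dispatches on
# `training`.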

# HACK: save_mean and save_var are going to be passed in as
# requires_grad variables (even though we'll never backprop through
# them) so we need to prevent the unpacking from triggering an error.
- name: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
  save_mean: not_implemented("cudnn_batch_norm_backward save_mean")
  save_var: not_implemented("cudnn_batch_norm_backward save_var")
  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)

# Only the first three of the _cudnn_rnn outputs can have gradients.
# _cudnn_rnn outputs: (output, hy, cy, reserve, weight_buf)
- name: _cudnn_rnn(Tensor input, TensorList weight, int64_t weight_stride0, Tensor weight_buf, Tensor hx, Tensor cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, IntList batch_sizes, Tensor dropout_state)
  output_differentiability: [True, True, True, False, False]
  input, hx, cx, weight: "_cudnn_rnn_backward(input, weight, weight_stride0, result4, hx, cx, result0, grads[0], grads[1], grads[2], mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variables ? result3.clone() : result3, grad_input_mask)"

# miopen

- name: miopen_convolution_transpose(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: miopen_convolution_transpose_backward(self, grad, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)

- name: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList output_padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, true, output_padding, groups, benchmark, deterministic, true, grad_input_mask)

- name: miopen_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic)
  self, weight, bias: miopen_convolution_backward(self, grad, weight, padding, stride, dilation, groups, benchmark, deterministic, grad_input_mask)

- name: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic, std::array<bool,3> output_mask)
  grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, benchmark, deterministic, true, grad_input_mask)

- name: miopen_batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon)
  input, weight, bias: "training ? miopen_batch_norm_backward(input, grad.contiguous(), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)"

- name: miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor save_mean, Tensor save_var, double epsilon)
  save_mean: not_implemented("miopen_batch_norm_backward save_mean")
  save_var: not_implemented("miopen_batch_norm_backward save_var")
  input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask)

# mkldnn
- name: mkldnn_convolution(Tensor self, Tensor weight, Tensor bias, IntList padding, IntList stride, IntList dilation, int64_t groups)
  self, weight, bias: mkldnn_convolution_backward(self, grad, weight, padding, stride, dilation, groups, grad_input_mask)

# fft
- name: _fft_with_size(Tensor self, int64_t signal_ndim, bool complex_input, bool complex_output, bool inverse, IntList checked_signal_sizes, bool normalized, bool onesided, IntList output_sizes)
  self: fft_backward(self, grad, signal_ndim, complex_input, complex_output, inverse, checked_signal_sizes, normalized, onesided, output_sizes)

- name: unbind(Tensor self, int64_t dim)
  self: unbind_backward(grads, dim)

- name: stack(TensorList tensors, int64_t dim)
  tensors: unbind(grad, dim)

# fused RNN kernels

# Only the first two of the _thnn_fused_lstm_cell outputs can have gradients.
# _thnn_fused_lstm_cell outputs: (hy, cy, workspace)
- name: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor input_bias, Tensor hidden_bias)
  output_differentiability: [True, True, False]
  input_gates, hidden_gates, cx, input_bias, hidden_bias: _thnn_fused_lstm_cell_backward(grads[0], grads[1], cx, result1, result2, input_bias.defined())

- name: _thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor input_bias, Tensor hidden_bias)
  input_gates, hidden_gates, hx, input_bias, hidden_bias: _thnn_fused_gru_cell_backward(grad, result1, input_bias.defined())

# PackedSequence helpers
- name: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first)
  input: _pack_padded_sequence_backward(grad, input.sizes(), result1, batch_first)