Summary:
- Move batch norm from TH(CU)NN to native.
- Speedups in many cases (e.g. #12006) for CUDA due to the new block/grid layout and Welford-type mean/variance calculations (the latter for training mode).
- It splits the forward kernel in two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Compared to the ill-fated #12368:
- I changed the CPU kernel to not call `.sum()` from within parallel for. This seemed to have caused the breakage (NaN results) in TestModels.test_dcgan_netG (thank you houseroad for the repro; errors in assessment of the fix are my own).
- I updated the Half->Float upcasting in tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/13263
Differential Revision: D12946254
Pulled By: SsnL
fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
// NB: Must be at the top of file to avoid including the deprecated "math.h".
|
|
// https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
|
|
#ifdef _MSC_VER
|
|
#define _USE_MATH_DEFINES
|
|
#include <cmath>
|
|
#endif
|
|
|
|
#include "Functions.h"
|
|
#include <ATen/Utils.h>
|
|
#include <ATen/core/TensorOptions.h>
|
|
#include <ATen/WrapDimUtils.h>
|
|
#include <ATen/WrapDimUtilsMulti.h>
|
|
#include <ATen/SparseTensorUtils.h>
|
|
#include <ATen/ExpandUtils.h>
|
|
#include <ATen/core/Reduction.h>
|
|
|
|
#include <ciso646>
|
|
#include <algorithm>
|
|
#include <numeric>
|
|
#include <functional>
|
|
|
|
// ${generated_comment}
|
|
|
|
using at::Tensor;
|
|
using at::Scalar;
|
|
using at::IntList;
|
|
using at::TensorList;
|
|
|
|
namespace torch { namespace autograd { namespace generated {
|
|
|
|
namespace {
|
|
|
|
// Helper functions for autogenerated code
|
|
|
|
// A simple way to imperatively compute index ranges for slots
|
|
// that have been flattened
|
|
struct IndexRangeGenerator {
|
|
IndexRange range(size_t range_size) {
|
|
i += range_size;
|
|
return {i - range_size, i};
|
|
}
|
|
size_t size() { return i; }
|
|
private:
|
|
size_t i = 0;
|
|
};
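// Illustrative usage (not part of the generated code): successive calls hand
// out consecutive, non-overlapping [begin, end) slots, e.g.
//   IndexRangeGenerator gen;
//   auto r0 = gen.range(2);  // {0, 2}
//   auto r1 = gen.range(3);  // {2, 5}
//   gen.size();              // 5, the total number of flattened slots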
|
|
|
|
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
|
|
AT_ASSERT(range.second <= out.size());
|
|
AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
|
|
out[range.first] = t;
|
|
}
|
|
|
|
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
|
|
AT_ASSERT(range.second <= out.size());
|
|
AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
|
|
std::copy(t.begin(), t.end(), out.begin() + range.first);
|
|
}
|
|
|
|
Tensor not_implemented(const char* name) {
|
|
throw std::runtime_error(
|
|
std::string("the derivative for '") + name + "' is not implemented");
|
|
}
|
|
|
|
Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
|
|
bool is_one = false;
|
|
if (s.isFloatingPoint()) {
|
|
is_one = s.toDouble() == 1;
|
|
} else if(s.isIntegral()) {
|
|
is_one = s.toLong() == 1;
|
|
}
|
|
|
|
if (is_one) {
|
|
return t;
|
|
} else {
|
|
return t * s;
|
|
}
|
|
}
|
|
|
|
int64_t _safe_size(IntList sizes, int64_t dim) {
|
|
dim = at::maybe_wrap_dim(dim, sizes.size());
|
|
return sizes.size() != 0 ? sizes[dim] : 1;
|
|
}
|
|
|
|
Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) {
|
|
double p = p_.toDouble();
|
|
Tensor self_scaled;
|
|
Tensor scale_v;
|
|
if (p == 0.0) {
|
|
return zeros_like(self);
|
|
} else if (p == 1.0) {
|
|
return self.sign() * grad;
|
|
} else if (p == 2.0) {
|
|
self_scaled = self;
|
|
scale_v = grad / norm;
|
|
} else if (std::isinf(p)) {
|
|
self_scaled = self.sign() * (self.abs() == norm).toType(self.type());
|
|
scale_v = grad.clone();
|
|
} else if (p < 2.0) {
|
|
self_scaled = self.sign() * self.abs().pow(p - 1);
|
|
scale_v = grad / norm.pow(p - 1);
|
|
} else {
|
|
self_scaled = self * self.abs().pow(p - 2);
|
|
scale_v = grad / norm.pow(p - 1);
|
|
}
|
|
// handle case at 0 where we return a subgradient containing 0
|
|
scale_v.masked_fill_(norm == 0, 0);
|
|
return self_scaled * scale_v;
|
|
}
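// For finite p > 0 the branches above all implement the analytic derivative
//   d||x||_p / dx_i = sign(x_i) * |x_i|^(p - 1) / ||x||_p^(p - 1),
// split by p only to keep the exponents numerically well behaved (e.g. to
// avoid 0^(negative) when p < 2); p == 0 and p == inf get special cases.
// Illustrative check (not generated code):
//   x = [3, -4], p = 2, ||x||_2 = 5  =>  d||x||_2 / dx = x / 5 = [0.6, -0.8],
// which matches the p == 2 branch (self_scaled = self, scale_v = grad / norm).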
|
|
|
|
Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor norm, int64_t dim, bool keepdim) {
|
|
if (!keepdim && self.dim() != 0) {
|
|
grad = grad.unsqueeze(dim);
|
|
norm = norm.unsqueeze(dim);
|
|
}
|
|
return norm_backward(grad, self, p_, norm);
|
|
}
|
|
|
|
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
|
|
double exponent = exponent_.toDouble();
|
|
if (exponent == 0.0) {
|
|
return zeros_like(self);
|
|
} else {
|
|
return grad * exponent * self.pow(exponent - 1);
|
|
}
|
|
}
|
|
|
|
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
|
|
return at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * exponent * self.pow(exponent - 1));
|
|
}
|
|
|
|
Tensor pow_backward_exponent(Tensor grad, const Tensor & self, const Tensor & exponent) {
|
|
return grad * self.pow(exponent) * self.log();
|
|
}
|
|
|
|
Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor & exponent) {
|
|
return grad * at::pow(base, exponent) * std::log(base.toDouble());
|
|
}
|
|
|
|
Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
|
|
Tensor args = at::arange(-p + 1, 1, -1, self.options()).div_(2.);
|
|
args = args.add(self.unsqueeze(-1));
|
|
return grad * args.digamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.);
|
|
}
|
|
|
|
Tensor permute_backwards(const Tensor & grad, IntList fwd_dims) {
|
|
// invert the permutation
|
|
auto ndims = fwd_dims.size();
|
|
std::vector<int64_t> dims(ndims);
|
|
for (size_t i = 0; i < ndims; i++) {
|
|
dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
|
|
}
|
|
return grad.permute(dims);
|
|
}
|
|
|
|
Tensor sum_backward(const Tensor & grad, IntList sizes, IntList dims, bool keepdim) {
|
|
if (!keepdim && sizes.size() > 0) {
|
|
if (dims.size()==1) {
|
|
return grad.unsqueeze(dims[0]).expand(sizes);
|
|
} else {
|
|
auto dims_to_unsqueeze = dim_list_to_bitset(dims, sizes.size());
|
|
Tensor res = grad;
|
|
for (size_t i = 0; i < sizes.size(); i++){
|
|
if (dims_to_unsqueeze[i]) {
|
|
res = res.unsqueeze(i);
|
|
}
|
|
}
|
|
return res.expand(sizes);
|
|
}
|
|
} else {
|
|
return grad.expand(sizes);
|
|
}
|
|
}
|
|
|
|
Tensor reverse_dim(const Tensor& t, int64_t dim) {
|
|
Tensor index = at::arange(t.size(dim) - 1, -1, -1, t.options().dtype(at::kLong));
|
|
return t.index_select(dim, index);
|
|
}
|
|
|
|
Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
|
|
if (inp.size(dim) == 1) {
|
|
return grad;
|
|
}
|
|
|
|
auto ones_size = inp.sizes().vec();
|
|
ones_size[dim] = 1;
|
|
Tensor ones = at::ones(ones_size, grad.options());
|
|
Tensor exclusive_normal_nocp = at::cat({ones, inp.narrow(dim, 0, inp.size(dim) - 1)}, dim);
|
|
Tensor exclusive_normal = exclusive_normal_nocp.cumprod(dim);
|
|
|
|
Tensor narrow_reverse = reverse_dim(inp.narrow(dim, 1, inp.size(dim) - 1), dim);
|
|
Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
|
|
Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);
|
|
|
|
return grad * (exclusive_normal * exclusive_reverse);
|
|
}
|
|
|
|
// note that the gradient for prod is equivalent to:
|
|
// cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
|
|
// input: [ a, b, c]
|
|
// cumprod(exclusive, normal): [1 , a, a * b]
|
|
// cumprod(exclusive, reverse): [b * c, c, 1]
|
|
// product: [b * c, a * c, a * b]
|
|
// and this is safe under input with 0s.
|
|
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
|
|
if (input.dim() == 0) {
|
|
return grad;
|
|
}
|
|
Tensor zero_idx = (input == 0).nonzero();
|
|
if (zero_idx.numel() == 0) {
|
|
return (grad * result) / input;
|
|
} else if (zero_idx.size(0) > 1) {
|
|
return zeros_like(input);
|
|
} else {
|
|
return prod_safe_zeros_backward(grad, input.contiguous().view(-1), 0).view_as(input);
|
|
}
|
|
}
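// Illustrative example (not generated code) of the single-zero branch above:
//   input = [2, 0, 3], result = prod(input) = 0
//   exclusive normal cumprod  : [1, 2, 0]
//   exclusive reverse cumprod : [0, 3, 1]
//   d result / d input = [1 * 0, 2 * 3, 0 * 1] = [0, 6, 0]
// Only the zero entry gets a nonzero gradient, as expected: perturbing it is
// the only way to change the product.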
|
|
|
|
Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
|
|
if (input.dim() == 0) {
|
|
return grad;
|
|
}
|
|
dim = at::maybe_wrap_dim(dim, input.sizes().size());
|
|
if (!keepdim && input.dim() != 1) {
|
|
grad = grad.unsqueeze(dim);
|
|
result = result.unsqueeze(dim);
|
|
}
|
|
|
|
Tensor zero_mask = (input == 0);
|
|
Tensor slice_zero_count = zero_mask.sum(dim, true);
|
|
int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
|
|
if (total_zeros == 0) {
|
|
return (grad * result) / input;
|
|
} else {
|
|
return prod_safe_zeros_backward(grad, input, dim);
|
|
}
|
|
}
|
|
|
|
Tensor sum_scan_exclusive(const Tensor& x, int64_t dim) {
|
|
Tensor ret = at::cumsum(-x, dim);
|
|
|
|
int64_t end_idx = ret.size(dim) - 1;
|
|
Tensor ret_sum = ret.narrow(dim, end_idx, 1).clone();
|
|
ret -= ret_sum.expand_as(ret);
|
|
ret += x;
|
|
return ret;
|
|
}
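// Illustrative example (not generated code): along dim 0 this produces the
// reversed (suffix) cumulative sum, e.g. for x = [1, 2, 3]:
//   cumsum(-x)            = [-1, -3, -6]
//   subtract last element = [ 5,  3,  0]
//   add x back            = [ 6,  5,  3],  i.e. ret[k] = sum_{j >= k} x[j],
// which is exactly the sum over j >= k needed in cumprod_backward below.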
|
|
|
|
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim) {
|
|
/*
|
|
There are two algorithms to do this. The first one
|
|
is very efficient, but works only when there are no
|
|
zero elements in the input.
|
|
|
|
The second one is much more complex, but it doesn't
|
|
assume anything on the input. The main downside is
|
|
that it takes time O(n^2), where n = input.size(self.dim)
|
|
(i.e. the length of the cumulative product). This is in
|
|
contrast to the forward pass and the efficient algorithm,
|
|
which are both O(n).
|
|
|
|
The second algorithm is a simple application of the chain
|
|
rule. If x is an n-dimensional vector, and y = cumprod(x),
|
|
and F is the final cost, then
|
|
|
|
dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k) (1)
|
|
|
|
The term dF / dy_j is just grad_output[j] (assuming again
|
|
everything is one-dimensional).
|
|
|
|
The term (dy_j / dx_k) is easily seen to be
|
|
|
|
if j >= k
|
|
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
|
|
else:
|
|
dy_j / dx_k = 0
|
|
|
|
Note that the indicator (j>=k) can be taken out
|
|
by replacing the sum in (1) with a sum from
|
|
j = k to n.
|
|
|
|
Thus,
|
|
df / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
|
|
|
|
with
|
|
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i (2)
|
|
|
|
Note that this last term is just the cumulative product
|
|
with k omitted. Thus, if x_k (the input) is nonzero, we can
|
|
just express this as
|
|
|
|
dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
|
|
= y_j / x_k
|
|
|
|
So therefore,
|
|
|
|
df / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k
|
|
|
|
so
|
|
|
|
grad_input = sum_scan_exclusive(grad_output * output) / input
|
|
|
|
If the input contains zero elements, we need to calculate the dy_j / dx_k
|
|
by using the formula (2), called in the code omitted_products.
|
|
|
|
The way the code calculates it is simply by noting that
|
|
|
|
prod_{1 <= i <= j, i != k} x_i
|
|
= (prod_{1 <= i <= k} x_i) * (prod_{k + 1 <= i <= j} x_i)
|
|
|
|
the first term is calculated as prods_until_k, which, since it
|
|
doesn't depend on j, is easy to vectorize.
|
|
|
|
The second term (indexed by j) is the cumulative product of
|
|
x_{k+1}, x_{k+2}, ..., x_n, and it's named in the code
|
|
prods_from_k_plus_1, and it's calculated as a cumprod.
|
|
|
|
In order to vectorize this properly, we need to add to
|
|
omitted_products the dimensions where k > j, and therefore
|
|
dy_j / dx_k = 0, which is done right after the assert.
|
|
*/
|
|
|
|
if (input.dim() == 0) {
|
|
return grad;
|
|
}
|
|
dim = at::maybe_wrap_dim(dim, input.sizes().size());
|
|
int64_t dim_size = input.size(dim);
|
|
if (dim_size == 1) {
|
|
return grad;
|
|
}
|
|
|
|
// Simple case with nonzero elements in the input
|
|
if ((input != 0).all().item<uint8_t>()) {
|
|
Tensor result = at::cumprod(input, dim);
|
|
return sum_scan_exclusive(result * grad, dim) / input;
|
|
}
|
|
|
|
auto ones_size = input.sizes().vec();
|
|
ones_size[dim] = 1;
|
|
Tensor ones = at::ones({1}, grad.options()).expand(ones_size);
|
|
Tensor grad_input = at::zeros(input.sizes(), grad.options());
|
|
Tensor prods_from_k_plus_1;
|
|
Tensor omitted_products;
|
|
for (int k = 0; k < dim_size; ++k) {
|
|
if (k == 0) {
|
|
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
|
|
omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
|
|
} else if (k == dim_size - 1) {
|
|
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
|
|
omitted_products = prods_until_k;
|
|
} else {
|
|
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
|
|
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
|
|
omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
|
|
omitted_products = at::cat({prods_until_k, omitted_products}, dim);
|
|
}
|
|
|
|
// At this point omitted_products is the same size
|
|
// as input, except on the dimension dim where it's
|
|
// dim_size - k
|
|
AT_ASSERT(omitted_products.size(dim) == dim_size - k);
|
|
|
|
grad_input.select(dim, k).copy_(
|
|
at::sum(grad.slice(dim, k) * omitted_products,dim));
|
|
}
|
|
|
|
return grad_input;
|
|
}
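// Illustrative example (not generated code) of the O(n^2) branch: with
//   input = [2, 0, 3] along dim 0, y = cumprod(input) = [2, 0, 0],
// the omitted products for k = 1 (the zero entry) are
//   j = 1: prod_{i <= 1, i != 1} x_i = 2
//   j = 2: 2 * 3 = 6
// so dL/dx_1 = grad[1] * 2 + grad[2] * 6, while dL/dx_2 = 0 because every
// omitted product for k = 2 still contains the zero at index 1.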
|
|
|
|
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim, ScalarType dtype) {
|
|
return cumprod_backward(grad.to(input.scalar_type()), input, dim);
|
|
}
|
|
|
|
Tensor gesv_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
|
|
return std::get<0>(at::gesv(grad, A.transpose(-2, -1)));
|
|
}
|
|
|
|
Tensor gesv_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
|
|
Tensor grad_self = gesv_backward_self(grad, self, A);
|
|
if (self.ndimension() == 2 && A.ndimension() == 2) {
|
|
return -at::mm(grad_self, solution.transpose(-2, -1));
|
|
}
|
|
return -at::matmul(grad_self, solution.transpose(-2, -1));
|
|
}
|
|
|
|
Tensor cumsum_backward(const Tensor & x, int64_t dim) {
|
|
if (x.dim() == 0) {
|
|
return x;
|
|
}
|
|
auto ret = at::cumsum(-x, dim);
|
|
auto ret_sum = ret.narrow(dim, ret.size(dim) - 1, 1).clone();
|
|
ret -= ret_sum.expand(ret.sizes());
|
|
ret += x;
|
|
return ret;
|
|
}
|
|
|
|
Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
|
|
return cumsum_backward(x.to(input_dtype), dim);
|
|
}
|
|
|
|
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
|
|
if (!keepdim && self.dim() != 0) {
|
|
grad = grad.unsqueeze(dim);
|
|
result = result.unsqueeze(dim);
|
|
}
|
|
return grad * (self - result).exp();
|
|
}
|
|
|
|
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
|
|
IntList sizes;
|
|
at::TensorOptions o;
|
|
for (auto v : grads) {
|
|
if (v.defined()) {
|
|
sizes = v.sizes();
|
|
o = static_cast<Tensor>(v).options();
|
|
break;
|
|
}
|
|
}
|
|
auto grads_tensors = fmap(grads, [&](const Variable &v) { return (v.defined() ? static_cast<Tensor>(v): at::zeros({}, o).expand(sizes));});
|
|
return at::stack(grads_tensors, dim);
|
|
}
|
|
|
|
Tensor unsqueeze_to(const Tensor & self, IntList sizes) {
|
|
auto result = self;
|
|
|
|
int64_t nDims = sizes.size();
|
|
for (int64_t dim = 0; dim < nDims; dim++) {
|
|
if (sizes[dim] == 1) {
|
|
result = result.unsqueeze(dim);
|
|
}
|
|
}
|
|
return result;
|
|
}
|
|
|
|
Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntList sizes) {
|
|
dim = at::maybe_wrap_dim(dim, sizes.size());
|
|
// in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid
|
|
// unsqueezing in the backward.
|
|
if (sizes.size() > 0 && sizes[dim] == 1) {
|
|
return self.unsqueeze(dim);
|
|
}
|
|
return self;
|
|
}
|
|
|
|
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
|
|
dim = at::legacy_cat_wrap_dim(dim, sizes);
|
|
std::vector<Tensor> grad_inputs(sizes.size());
|
|
int64_t accumulate = 0;
|
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
|
auto& shape = sizes[i];
|
|
// If input was empty tensor, gradInput should be empty tensor.
|
|
if (shape == std::vector<int64_t>({0})) {
|
|
grad_inputs[i] = at::zeros({0}, grad.options());
|
|
continue;
|
|
}
|
|
auto size = shape[dim];
|
|
accumulate += size;
|
|
grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
|
|
}
|
|
return grad_inputs;
|
|
}
|
|
|
|
Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
|
|
// clamp: gradients not defined on min and max, so we return the subgradient 1 for these cases.
|
|
if (max && min) {
|
|
return grad * ((self >= *min) * (self <= *max)).type_as(grad);
|
|
} else if (min) {
|
|
return grad * (self >= *min).type_as(grad);
|
|
} else if (max) {
|
|
return grad * (self <= *max).type_as(grad);
|
|
} else {
|
|
return grad;
|
|
}
|
|
}
|
|
|
|
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, IntList sizes, IntList strides, const Scalar & alpha) {
|
|
// if input was column-major, return grad as column-order for efficiency
|
|
if (strides[0] == 1 && strides[1] == sizes[0]) {
|
|
return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
|
|
} else {
|
|
return maybe_multiply(grad.mm(mat2.t()), alpha);
|
|
}
|
|
}
|
|
|
|
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor & mat1, const Scalar & alpha) {
|
|
// if input was column-major, return grad as column-order for efficiency
|
|
if (mat1.is_sparse()) {
|
|
throw std::runtime_error("calculating the gradient of a sparse Tensor argument to mm is not supported.");
|
|
}
|
|
at::IntList sizes = mat1.sizes();
|
|
at::IntList strides = mat1.strides();
|
|
if (strides[0] == 1 && strides[1] == sizes[0]) {
|
|
return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
|
|
} else {
|
|
return maybe_multiply(grad.mm(mat2.t()), alpha);
|
|
}
|
|
}
|
|
|
|
Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntList sizes, IntList strides, const Scalar & alpha) {
|
|
// if input was column-major, return grad as column-order for efficiency
|
|
if (strides[0] == 1 && strides[1] == sizes[0]) {
|
|
return maybe_multiply(grad.t().mm(mat1).t(), alpha);
|
|
} else {
|
|
return maybe_multiply(mat1.t().mm(grad), alpha);
|
|
}
|
|
}
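// mm_mat1_backward / mm_mat2_backward above implement the usual matrix-product
// derivatives: for C = alpha * (A @ B),
//   dL/dA = alpha * (dL/dC) @ B^T   (mm_mat1_backward)
//   dL/dB = alpha * A^T @ (dL/dC)   (mm_mat2_backward)
// The column-major branches just evaluate the transposed expression, e.g.
// (B @ G^T)^T == G @ B^T, so the gradient is materialized in the same memory
// layout as the corresponding input.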
|
|
|
|
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
|
|
auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
|
|
auto flatten = [&](const Tensor & t) {
|
|
return t.transpose(dim, 0).contiguous().view({t.size(dim), -1});
|
|
};
|
|
auto unflatten = [&](const Tensor & t) {
|
|
return t.contiguous().view(transposed_sizes).transpose(dim, 0);
|
|
};
|
|
|
|
// renorm computes the norm over all dimensions except `dim`, which is why
|
|
// we need the flatten and unflatten business. TODO: simplify this when we
|
|
// add support for norm over multiple dimensions.
|
|
auto self_flat = flatten(self);
|
|
auto grad_flat = flatten(grad);
|
|
auto norm_flat = self_flat.norm(p, 1, true);
|
|
auto grad_output = (self_flat * grad_flat).sum(1, true);
|
|
auto nb = norm_backward(grad_output, self_flat, p, norm_flat, 1, true);
|
|
auto invnorm = (norm_flat + 1e-7).reciprocal();
|
|
auto grad_norm = unflatten(maxnorm * invnorm * (grad_flat - invnorm * nb));
|
|
auto norm = unflatten(norm_flat.expand_as(self_flat));
|
|
|
|
// TODO: remove the detach once comparison ops no longer require grad
|
|
auto mask = Variable(norm < maxnorm).detach();
|
|
return at::where(mask, grad, grad_norm);
|
|
}
|
|
|
|
Tensor sum_tensorlist(TensorList tl) {
|
|
if (tl.size() == 0) {
|
|
throw std::runtime_error("Can't sum tensorlist of size 0");
|
|
}
|
|
Tensor sum = tl[0];
|
|
for(size_t i = 1; i < tl.size(); ++i) {
|
|
sum = sum + tl[i];
|
|
}
|
|
return sum;
|
|
}
|
|
|
|
Tensor repeat_backward(Tensor grad, int64_t input_dims, IntList repeats) {
|
|
int64_t num_unsqueezed = grad.dim() - input_dims;
|
|
for (int64_t i = 0; i < num_unsqueezed; ++i) {
|
|
grad = grad.sum(0, false);
|
|
}
|
|
for (size_t j = num_unsqueezed; j < repeats.size(); ++j) {
|
|
int64_t repeat = repeats[j];
|
|
if (repeat == 1) {
|
|
continue;
|
|
}
|
|
int64_t dim = j - num_unsqueezed;
|
|
grad = sum_tensorlist(grad.chunk(repeat, dim));
|
|
}
|
|
return grad;
|
|
}
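// Illustrative example (not generated code):
//   x = torch.randn(2); y = x.repeat(3, 2)   # y has shape [3, 4]
// Here num_unsqueezed = 2 - 1 = 1, so the leading repeat dimension is summed
// away (grad becomes shape [4]); then repeats[1] = 2 chunks it into two [2]
// pieces which are summed, so each input element accumulates 3 * 2 = 6 of the
// incoming gradient entries, one per copy made by repeat.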
|
|
|
|
// p1m == 1 - p
|
|
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
|
|
if (grad.requires_grad()) {
|
|
// Use autograd-friendly backward if double backward is required
|
|
return grad * (mask.type_as(grad) * (1. / p1m));
|
|
} else {
|
|
return at::_masked_scale(grad, mask, 1. / p1m);
|
|
}
|
|
}
|
|
|
|
Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor & value) {
|
|
auto grad_input = zeros_like(input);
|
|
grad_input.masked_fill_(input == value, grad);
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntList sizes, bool keepdim) {
|
|
if (!keepdim && sizes.size() > 0) {
|
|
grad = grad.unsqueeze(dim);
|
|
indices = indices.unsqueeze(dim);
|
|
}
|
|
return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad);
|
|
}
|
|
|
|
Tensor slice_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
|
auto grad_input = at::zeros(input_sizes, grad.options());
|
|
grad_input.slice(dim, start, end, step).copy_(grad);
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor select_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t index) {
|
|
auto grad_input = at::zeros(input_sizes, grad.options());
|
|
grad_input.select(dim, index).copy_(grad);
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor trace_backward(const Tensor & grad, IntList sizes) {
|
|
if (sizes.size() != 2) {
|
|
throw std::runtime_error("expected matrix input");
|
|
}
|
|
|
|
auto grad_input = at::zeros(sizes[0] * sizes[1], grad.options());
|
|
auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
|
|
grad_input.index_fill_(0, indices, grad);
|
|
return grad_input.view(sizes);
|
|
}
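// Illustrative example (not generated code): for a 2 x 3 input, numel = 6 and
// the arange step is sizes[1] + 1 = 4, giving indices = [0, 4] -- exactly the
// flattened positions of the diagonal elements (0, 0) and (1, 1). The scalar
// grad of trace is written there and every other entry stays zero.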
|
|
|
|
Tensor unfold_backward(const Tensor & grad, IntList input_sizes, int64_t dim, int64_t size, int64_t step) {
|
|
|
|
int64_t numel = 1;
|
|
for (auto size : input_sizes) {
|
|
numel *= size;
|
|
}
|
|
|
|
auto idx = at::arange(0, numel, grad.options().dtype(at::kLong)).view(input_sizes);
|
|
auto idx_unfolded = idx.unfold(dim, size, step).contiguous().view(-1);
|
|
auto grad_input = at::zeros({numel}, grad.options());
|
|
grad_input.index_add_(0, idx_unfolded, grad.contiguous().view(-1));
|
|
return grad_input.view(input_sizes);
|
|
}
|
|
|
|
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
|
|
return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
|
|
}
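// This is the analytic derivative of the (optionally unbiased) variance,
//   var(x) = sum_i (x_i - mean(x))^2 / (N - unbiased),
//   d var / d x_i = 2 * (x_i - mean(x)) / (N - unbiased),
// where the terms coming from d mean / d x_i cancel because
// sum_j (x_j - mean(x)) = 0; the backward is that expression times grad.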
|
|
|
|
Tensor var_backward(Tensor grad, const Tensor & self, int64_t dim, bool unbiased, bool keepdim) {
|
|
if (self.dim() == 0) {
|
|
return var_backward(grad, self, unbiased);
|
|
}
|
|
if (!keepdim && self.dim() > 1) {
|
|
grad = grad.unsqueeze(dim);
|
|
}
|
|
return (2.0 / (self.size(dim) - unbiased)) * grad * (self - self.mean(dim, true));
|
|
}
|
|
|
|
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList sizes) {
|
|
int64_t numel = 1;
|
|
for (auto size : sizes) {
|
|
numel *= size;
|
|
}
|
|
auto mask_selected = grad.masked_select(mask);
|
|
auto diff_nelem = numel - mask_selected.numel();
|
|
if (diff_nelem > 0) {
|
|
// because masked_select returns a 1-d tensor whose size is the number of mask elements that are 1,
|
|
// we need to fill out the rest with zeros and then reshape back to the size of the source tensor (tensor2).
|
|
auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
|
|
mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
|
|
}
|
|
return mask_selected.view(sizes);
|
|
}
|
|
|
|
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
|
|
// cf. Iain Murray (2016); arXiv 1602.07527
|
|
if (upper) {
|
|
L = L.t();
|
|
grad = grad.t();
|
|
}
|
|
|
|
auto phi = [](const Tensor & A) -> Tensor {
|
|
auto B = A.tril();
|
|
B = B - 0.5 * at::diag(at::diag(B));
|
|
return B;
|
|
};
|
|
|
|
// make sure not to double-count variation, since
|
|
// only half of output matrix is unique
|
|
auto Lbar = grad.tril();
|
|
|
|
auto P = phi(at::mm(L.t(), Lbar));
|
|
Tensor S;
|
|
std::tie(S, std::ignore) = at::gesv(P + P.t(), L.t());
|
|
std::tie(S, std::ignore) = at::gesv(S.t(), L.t());
|
|
S = phi(S);
|
|
if (upper) {
|
|
S = S.t();
|
|
}
|
|
return S;
|
|
}
|
|
|
|
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
|
|
IntList split_sizes, int64_t dim, IntList sizes, const Type &type) {
|
|
dim = at::maybe_wrap_dim(dim, sizes.size());
|
|
|
|
// it's possible some of the grads are not defined (represents tensors of all 0s).
|
|
// Since at::cat can't handle those, let's define them
|
|
std::vector<Tensor> grads_all_defined(grads.size());
|
|
for (size_t j = 0; j < grads.size(); ++j) {
|
|
if (grads[j].defined()) {
|
|
grads_all_defined[j] = grads[j];
|
|
} else {
|
|
auto length = split_sizes[j];
|
|
auto grad_size = sizes.vec();
|
|
grad_size[dim] = length;
|
|
grads_all_defined[j] = at::zeros(grad_size, type);
|
|
}
|
|
}
|
|
|
|
auto ret = at::cat(grads_all_defined, dim);
|
|
return ret;
|
|
}
|
|
|
|
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
|
|
int64_t split_size, int64_t dim, IntList sizes, const Type &type) {
|
|
dim = at::maybe_wrap_dim(dim, sizes.size());
|
|
int64_t dim_size = sizes[dim];
|
|
int64_t num_splits = grads.size();
|
|
std::vector<int64_t> split_sizes(num_splits, split_size);
|
|
split_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size);
|
|
return split_with_sizes_backward(grads, split_sizes, dim, sizes, type);
|
|
}
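// Illustrative example (not generated code): dim_size = 10 split with
// split_size = 3 produces 4 chunks, and the last chunk has
//   3 - (3 * 4 - 10) = 1
// element, so split_sizes = [3, 3, 3, 1] before delegating to
// split_with_sizes_backward.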
|
|
|
|
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
|
|
AT_ASSERT(indices.dim() >= dim);
|
|
auto size = indices.sizes().slice(0, indices.dim() - dim).vec();
|
|
size.push_back(-1);
|
|
auto indices_view = indices.view(size);
|
|
return grad.contiguous().view(size).gather(-1, indices_view).view(indices.sizes());
|
|
}
|
|
|
|
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
|
|
auto& gO = grad_output;
|
|
auto input_size = input.size(dim) / 2;
|
|
auto first_half = input.narrow(dim, 0, input_size);
|
|
auto second_half = input.narrow(dim, input_size, input_size);
|
|
auto sig_second_half = second_half.sigmoid();
|
|
auto one_sub_sig_second_half = 1 - sig_second_half;
|
|
auto sig_one_sub_sig = sig_second_half * one_sub_sig_second_half;
|
|
|
|
auto ggI_first_half = grad.narrow(dim, 0, input_size);
|
|
auto ggI_second_half = grad.narrow(dim, input_size, input_size);
|
|
auto ggI_second_half_times_first_half = ggI_second_half * first_half;
|
|
|
|
auto gI_first_half = ggI_second_half * gO * sig_one_sub_sig;
|
|
auto second_order_sh = sig_one_sub_sig * one_sub_sig_second_half - sig_second_half * sig_one_sub_sig;
|
|
auto gI_second_half = ggI_second_half_times_first_half * gO * second_order_sh + ggI_first_half * gO * sig_one_sub_sig;
|
|
return at::cat({gI_first_half, gI_second_half}, dim);
|
|
}
|
|
|
|
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
|
|
if (dim < 0) dim += input.dim();
|
|
auto sizes = input.sizes().vec();
|
|
sizes[dim] /= 2;
|
|
auto tmp = grad * glu_backward(at::ones(sizes, input.options()), input, dim);
|
|
return tmp.narrow(dim, 0, sizes[dim]) + tmp.narrow(dim, sizes[dim], sizes[dim]);
|
|
}
|
|
|
|
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
auto result = kl_div_backward(grad, input, target, Reduction::None);
|
|
if (reduction == Reduction::Mean) {
|
|
return result.mean();
|
|
} else if (reduction == Reduction::Sum) {
|
|
return result.sum();
|
|
}
|
|
return result;
|
|
}
|
|
|
|
// Compute derivatives for targets.
|
|
// Assume targets are given as probabilities (i.e. without taking the logarithm).
|
|
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) {
|
|
if (reduction == Reduction::None) {
|
|
return grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
|
|
}
|
|
if (reduction == Reduction::Mean) {
|
|
return grad_output.mul(target.log().add_(1).sub_(self)).div_(target.numel()).masked_fill_(target == 0, 0.);
|
|
}
|
|
return grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
|
|
}
|
|
|
|
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) {
|
|
Tensor grad_target;
|
|
if (pos_weight.defined()) {
|
|
grad_target = (1. - self.sigmoid()).log_().sub_(pos_weight.mul(self.sigmoid().log_())).mul_(grad_output);
|
|
} else {
|
|
grad_target = self.mul(-grad_output);
|
|
}
|
|
|
|
if (weight.defined()) {
|
|
grad_target.mul_(weight);
|
|
}
|
|
|
|
if (reduction == Reduction::Mean) {
|
|
grad_target.div_(target.numel());
|
|
}
|
|
|
|
return grad_target;
|
|
}
|
|
|
|
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
|
|
auto z = input.sigmoid();
|
|
return grad * (z - 1) * z;
|
|
}
|
|
|
|
Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
|
|
auto gO = grad_output;
|
|
auto ggI = grad;
|
|
|
|
auto ggI_output = ggI * output;
|
|
auto ggI_out_sum = ggI_output.sum(dim, true);
|
|
auto ggI_out_sum_output = ggI_out_sum * output;
|
|
auto gO_out_sum = (gO * output).sum(dim, true);
|
|
|
|
// gI calculation
|
|
auto gI_t0 = ggI_output * (gO - gO_out_sum);
|
|
auto gI_t1 = output * ((ggI_output * gO).sum(dim, true).sub_(gO_out_sum * ggI_out_sum));
|
|
auto gI_t2 = ggI_out_sum_output * gO;
|
|
auto gI_t3 = ggI_out_sum_output * gO_out_sum;
|
|
return gI_t0 - gI_t1 - gI_t2 + gI_t3;
|
|
}
|
|
|
|
Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
|
|
auto z = output.exp();
|
|
return z * grad_output.sum(dim, true) * ((grad * z).sum(dim, true) - grad);
|
|
}
|
|
|
|
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
auto output = l1_loss_backward(grad, input, target, Reduction::None);
|
|
if (reduction == Reduction::Mean) {
|
|
return output.mean();
|
|
} else if (reduction == Reduction::Sum) {
|
|
return output.sum();
|
|
}
|
|
return output;
|
|
}
|
|
|
|
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
auto d = (input - target).abs();
|
|
auto grad_input = grad * (d < 1).toType(grad.type());
|
|
if (reduction == Reduction::Mean) {
|
|
grad_input /= input.numel();
|
|
}
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
if (reduction == Reduction::None) {
|
|
return smooth_l1_loss_backward(grad, input, target, reduction);
|
|
}
|
|
auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction);
|
|
return (r * grad).sum();
|
|
}
|
|
|
|
Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
|
|
auto ndimension = input_sizes.size();
|
|
AT_ASSERT(ndimension == 1 || ndimension == 2);
|
|
|
|
if (ndimension == 1 || input_sizes[0] == input_sizes[1]) {
|
|
return grad.diag(diagonal);
|
|
}
|
|
|
|
// Input was a matrix but was not square
|
|
auto grad_input = at::zeros(input_sizes, grad.options());
|
|
auto diag = grad_input.diagonal(diagonal);
|
|
diag.copy_(grad);
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
|
|
auto grad_input = at::zeros(input_sizes, grad.options());
|
|
auto diag = grad_input.diagonal(offset, dim1, dim2);
|
|
diag.copy_(grad);
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
|
|
auto grad_input = 2 * grad;
|
|
if (reduction == Reduction::Mean) {
|
|
grad_input /= input.numel();
|
|
}
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
if (reduction == Reduction::None) {
|
|
return mse_loss_backward(grad, input, target, reduction);
|
|
}
|
|
auto r = mse_loss_backward(ones_like(grad_output), input, target, reduction);
|
|
return (r * grad).sum();
|
|
}
|
|
|
|
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
auto z = (input * -target).exp();
|
|
auto zplus1 = z + 1;
|
|
auto grad_input = grad * (target * target) * z / (zplus1 * zplus1);
|
|
if (reduction == Reduction::Mean) {
|
|
grad_input /= input.numel();
|
|
}
|
|
return grad_input;
|
|
}
|
|
|
|
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
|
|
if (reduction == Reduction::None) {
|
|
return soft_margin_loss_backward(grad, input, target, reduction);
|
|
}
|
|
auto r = soft_margin_loss_backward(ones_like(grad_output), input, target, reduction);
|
|
return (r * grad).sum();
|
|
}
|
|
|
|
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) {
|
|
auto x = (input * beta);
|
|
return _sigmoid_backward(grad, x.sigmoid()) * (x < threshold).toType(grad.type()) * beta;
|
|
}
|
|
|
|
|
|
// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
|
|
//
|
|
// `storage_offset` is ignored for simplicity in this note. If you just want the
|
|
// full algorithm without explanation, scroll down to bottom of this note.
|
|
//
|
|
// Implementing the backward of as_strided is tricky because you have to deal
|
|
// with mappings that map one memory location to multiple indices, i.e., the
|
|
// output tensor has multiple indices pointing to **overlapping** memory
|
|
// addresses. This can happen in all sorts of weird cases. For example,
|
|
//
|
|
// x = torch.randn(15)
|
|
// x.as_strided([3, 3], [1, 0]) # "expand" case
|
|
// x.as_strided([3, 3], [2, 1]) # "size too large" case
|
|
// x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6
|
|
// # res[0, 1] points to 0*3 + 1*6 = 6
|
|
//
|
|
// Here is the general strategy we apply in implementing as_strided backward:
|
|
// 0. ??? (optimization step. we will talk about this later)
|
|
// 1. Create some underlying flattened tensor as if it is the base tensor
|
|
// representing the contiguous memory storage for both input and output.
|
|
// 2. Use the output geometry to scatter (or index_add) the gradients into
|
|
// this storage tensor.
|
|
// 3. ??? (fix for input tensor with overlapping memory. we will talk about
|
|
// this later)
|
|
// 4. Return the as_strided view of the storage tensor using input geometry.
|
|
//
|
|
// In step (2), if the output tensor doesn't have overlapping memory, we can
|
|
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
|
|
// otherwise, we must use `index_add` as gradients at different indices may need
|
|
// to be summed to a single location.
|
|
//
|
|
// For example, in this case:
|
|
//
|
|
// x = torch.randn(3)
|
|
// y = x.as_strided([3, 3], [1, 0]) # "expand" case
|
|
// # size [ 3, 3]
|
|
// # stride [ 1, 0]
|
|
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
|
|
// is large enough to be used as underlying storage
|
|
// for `x` and `y`.
|
|
// s = [ 0, 0, 0]
|
|
// # step (2): since `y` has overlapping memory, index_add grad
|
|
// into `s` based on `y`'s geometry, i.e.,
|
|
// s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
|
|
// s = [ 3, 3, 3]
|
|
// # step (4): as_strided view `s` using `x`'s geometry
|
|
// s = [ 3, 3, 3]
|
|
// grad_input = s.as_strided(x.size(), x.stride())
|
|
// = s.as_strided([3], [1])
|
|
// = [ 3, 3, 3]
|
|
//
|
|
// This is exactly what we would get if using `expand`. However, here the input
|
|
// tensor doesn't have overlapping memory. If it does, we must add an extra step
|
|
// before (4). Considering this case:
|
|
//
|
|
// t = torch.randn(3)
|
|
// x = t.expand(3, 3) # input with overlapping memory
|
|
// # size [3, 3]
|
|
// # stride [0, 1]
|
|
// y = x.as_strided([1], [1]) # contiguous output
|
|
// # size [1]
|
|
// # stride [1]
|
|
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
|
|
// is large enough to be used as underlying storage
|
|
// for `x` and `y`.
|
|
// s = [ 0, 0, 0]
|
|
// # step (2): scatter grad into `s` based on `y`'s geometry
|
|
// s = [ 1, 0, 0]
|
|
// # step (4): as_strided view `s` using `x`'s geometry
|
|
// s = [ 1, 0, 0]
|
|
// grad_input = s.as_strided([3, 3], [0, 1])
|
|
// = s.as_strided([3, 3], [0, 1])
|
|
// = [[ 1, 0, 0],
|
|
// [ 1, 0, 0],
|
|
// [ 1, 0, 0]]
|
|
// Is this result correct?
|
|
//
|
|
// The `x.as_strided([1], [1])` call is obviously equivalent to
|
|
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
|
|
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
|
|
// case, indexing `x` at any index in first column is also equivalent, and
|
|
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
|
|
// an `x.size(1)`-times difference between these gradients computed from other
|
|
// PyTorch ops and the gradient we got from as_strided.
|
|
//
|
|
// You might conclude that the gradients from as_strided is wrong. However,
|
|
// let's first see why they are actually reasonable. Consider the pointwise
|
|
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
|
|
// a `delta` change in the same memory location, and then `y` will change by
|
|
// `delta`. So one can say the gradient should be exactly 1 at the first column,
|
|
// as given by our above procedure.
|
|
//
|
|
// In the above computation of numerical gradients, they only match the
|
|
// analytical results because strides and memory locations are considered in the
|
|
// forward pass, i.e., this op (including both forward and backward) is
|
|
// layout-aware.
|
|
//
|
|
// However, in PyTorch, most (probably all) other ops (forward and backward) are
|
|
// layout-agnostic. E.g.,
|
|
//
|
|
// t = torch.randn(1)
|
|
// x = t.expand(2)
|
|
// y = x.sum()
|
|
// y.backward()
|
|
//
|
|
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
|
|
//
|
|
// gy = 1
|
|
// gx = [ 1, 1] # SumBackward: torch.ones_like(x)
|
|
// gt = [ 2] # ExpandBackward: gx.sum()
|
|
//
|
|
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
|
|
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
|
|
// the gradients, if strides are taken into consideration, should be 2.
|
|
//
|
|
// Layout-aware autograd should give you
|
|
//
|
|
// gy = 1
|
|
// gx = [ 2, 2] # Because the backward considers the fact that the input `x`
|
|
// # is already expanded.
|
|
// gt = [ 2] # Layout-aware backward of expand is just a slicing because
|
|
// # the previous backward should have already taken care of
|
|
// # strides and made sure that gradients are the same along the
|
|
// # expanded dimension.
|
|
//
|
|
// As shown above, these two types are not compatible. Therefore, we must either
|
|
// make as_strided layout-agnostic, or make all other ops layout-aware.
|
|
//
|
|
// It is difficult to support layout-aware autograd (at least in the current
|
|
// codebase structure), because it would mean
|
|
// 1. storing tensor geometries of every input tensor for backward
|
|
// 2. depending on input geometry, the gradient computed from backward change
|
|
// 3. ideally enforcing gradient of T to always have same strides as T
|
|
// (although these two methods only differ when it comes to overlapping memory)
|
|
//
|
|
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
|
|
// giving the same output regardless of the input layout. We consider
|
|
// `input.stride()` as a separate independent fixed argument `input_stride`.
|
|
// Then, `as_strided(input, size, stride)` can be thought of as:
|
|
// 1. "Scatter" each value of `input` into a "storage" using storage location
|
|
// computed from the value's index in `input`, `input.size()` and
|
|
// `input_stride`, but if N values end up in the same location, the value
|
|
// is average of those N values (they will be the same value anyways).
|
|
//
|
|
// Formal description:
|
|
// Denote the set of all input indices that point to the same storage
|
|
// location `storage[n]` as `S(n)`, i.e.,
|
|
//
|
|
// S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
|
|
//
|
|
// where `<x, y>` is the dot product between `x` and `y`.
|
|
//
|
|
// Then, the process is:
|
|
//
|
|
// storage[n] = Avg { S(n) }
|
|
//
|
|
// Note that all values in `S(n)` are the same (they point to the same
|
|
// memory location anyways, so this step doesn't change anything, but
|
|
// effectively avoids having the dependency on the layout of `input`).
|
|
// I.e., the result holds fixed regardless of the layout of `input`, as
|
|
// long as `input_stride` is fixed.
|
|
//
|
|
// NOTE: for forward pass, we can equivalently simply select any one of
|
|
// `S(n)` as `storage[n]`. However, considering this as an average
|
|
// operation makes backward easier (so all values in set
|
|
// `{ grad_input[i] : i in S(n) }` are the same, and it can use the
|
|
// same geometry as input).
|
|
// 2. As usual, return the as_strided view of `storage` using required output
|
|
// `size` and `stride`.
|
|
//
|
|
// To backward through this layout-agnostic version, we simply add the following
|
|
// step:
|
|
// .... (scatter gradients into the storage tensor using output geometry)
|
|
// 3. For all storage location n, `storage[n] /= |S(n)|`.
|
|
// .... (return as_strided view of the storage tensor using input geometry)
|
|
//
|
|
// Finally, we note that these general operations are expensive, so we apply the
|
|
// following optimizations:
|
|
// Add step (0): For all output dimension `d` with output stride 0, sum the
|
|
// gradients along dimension `d` (don't keepdim), and remove
|
|
// dimension `d` from output size and stride.
|
|
// (An optimization for "expand" cases so we may avoid step (3))
|
|
// Only apply step (3) when input tensor has overlapping memory.
|
|
//
|
|
// FULL ALGORITHM:
|
|
// 0. For all output dimension `d` with output stride 0, sum the gradients
|
|
// along dimension `d` (don't keepdim), and remove dimension `d` from
|
|
// output size and stride.
|
|
// 1. Create some underlying flattened tensor as if it is the base tensor
|
|
// representing the contiguous memory storage for both input and output.
|
|
// 2. Use the output geometry to scatter (or index_add) the gradients into
|
|
// this storage tensor `storage`.
|
|
// 3. If input tensor has overlapping memory,
|
|
// For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
|
|
// number of indices in input geometry pointing to the same storage
|
|
// location `i` (i.e., `|S(i)|` in equations above).
|
|
// 4. Return the as_strided view of the storage tensor using input geometry.
|
|
//
|
|
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
|
|
// roughly detect overlapping memory.
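//
// Illustrative end-to-end example (not part of the note above; values worked
// out by hand, not taken from a test):
//
//     t = torch.randn(2)
//     x = t.expand(3, 2)               # overlapping input, size [3, 2], stride [0, 1]
//     y = x.as_strided([2, 2], [1, 0]) # gy = [[1, 2], [3, 4]]
//
//     Step (0): output dim 1 has stride 0, so sum it out: gy -> [3, 7],
//               output geometry becomes size [2], stride [1].
//     Step (1)+(2): storage s = [3, 7] (the reduced output no longer overlaps).
//     Step (3): the input geometry maps three indices onto each storage
//               location, so s /= [3, 3], giving [1, 7/3].
//     Step (4): grad_input = s.as_strided([3, 2], [0, 1])
//                          = [[1, 7/3], [1, 7/3], [1, 7/3]].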
|
|
|
|
|
|
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
|
|
//
|
|
// Checking memory overlap within a strided tensor is the special case of
|
|
// detecting memory overlap of two strided tensors, where the two tensors start
|
|
// at the same memory address. The latter is HARD (see #8212).
|
|
//
|
|
// But even this special case isn't simple. This note describes a check for an
|
|
// even more constrained simple case where we can be certain that there is no
|
|
// overlap.
|
|
//
|
|
// The checking algorithm can be described as:
|
|
// 0. Return [ pass check ] if any dimension has size 0
|
|
// 1. Ignore all dimensions that have size 1
|
|
// 2. If no remaining dimensions, return [ pass check ]
|
|
// 3. Sort the remaining dimensions according to the strides decreasingly
|
|
// 4. Check that for each dimension k,
|
|
//
|
|
// stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
|
|
//
|
|
// That is equivalent to, after reordering the dimensions so strides are
|
|
// in decreasing order, checking that stride of each dimension is larger
|
|
// than the maximum memory offset in a slice at that dimension.
|
|
//
|
|
// Obviously this check passes for contiguous tensors ( the dimensions will be
|
|
// already sorted with LHS = stride[0] = \prod size[i] being exactly 1 larger
|
|
// than RHS ). Similarly, the check passes for tensors contiguous in all but
|
|
// the last dimension, and LHS = stride[0] = stride[-1] * \prod size[i] being
|
|
// exactly stride[-1] larger than RHS. (*)
|
|
//
|
|
// We will show that these view operations, including all our view operations
|
|
// *except for* general as_strided and unfold, also preserve this invariant:
|
|
//
|
|
// alias: Obviously preserves
|
|
//
|
|
// expand: All changed dimensions are removed in step (1)
|
|
//
|
|
// view: Consider the input dimensions as grouped into consecutive
|
|
// dimension "blocks", where dimensions are contiguous in each one.
|
|
// view only works when the output dimensions can also be
|
|
// grouped into the same consecutive blocks of same ordering.
|
|
//
|
|
// NB: this means that the number of elements and stride of the
|
|
// last dimension in each block is the same in input and
|
|
// output. (**)
|
|
//
|
|
// Notation:
|
|
// Consider a single such block B,
|
|
// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
|
|
// start--^^^^ ^^^^^^^^^^^^--end
|
|
// Each B[i] denotes a dimension index such that B[i] = B[0] + i.
|
|
//
|
|
// We first show that if a tensor (i.e., the input) satisfies the
|
|
// invariant, then after sorting, the dimensions within each block
|
|
// still remain consecutive. (***)
|
|
//
|
|
// After removing dimensions of size 1, the dimensions within a
|
|
// block are already sorted by strides in descending order. So
|
|
// sorting all dimensions will not change the relative ordering
|
|
// among them.
|
|
//
|
|
// Assume that some block B is not consecutive after sorting,
|
|
// i.e., there exists a dimension d between B[0] and B[-1] in
|
|
// sorted order.
|
|
//
|
|
// By (*), we know that
|
|
// stride[B[0]]
|
|
// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
|
|
// < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
|
|
// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
|
|
// <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
|
|
//
|
|
// where the first < comes from sorting and
|
|
// the second <= comes from the fact that dimension d
|
|
// exists after step (1) and
|
|
// thus must have size greater
|
|
// than 1
|
|
// the third <= comes from the fact that each term in
|
|
// the sum is non-negative
|
|
//
|
|
// Then we have a contradiction, as the invariant cannot be
|
|
// satisfied at B[0]. So the original proposition is true.
|
|
//
|
|
// Now that we established the above claim (***), we consider the
|
|
// view operation as first sorting the dimensions (i.e., blocks),
|
|
// apply the original view (since it only cares about dimensions being
|
|
// consecutive and contiguous within each block), and then undo
|
|
// the sort.
|
|
//
|
|
// Consider a single block B in the output,
|
|
// ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
|
|
// start--^^^^ ^^^^^^^^^^^^--end
|
|
//
|
|
// By (*), we know that for all i
|
|
// stride[B[i]] = stride[B[-1]] +
|
|
// \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
|
|
//
|
|
// Then the invariant is obviously satisfied at every dimension
|
|
// in this block if it is satisfied at dimension B[-1]. It only
|
|
// remains to show that it is satisfied at the last dimension in
|
|
// each block.
|
|
//
|
|
// Since the same blocks are present in both input and output
|
|
// with the same ordering, we will abuse the notation in the
|
|
// following statements.
|
|
//
|
|
// By (*), we know that the following holds for both input and
|
|
// output, for any block B:
|
|
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
|
|
// = \sum_{block B' after B} \prod_{j in B'} size[B'[j]] * stride[B'[-1]]
|
|
// = \sum_{block B' after B} numel(B') * stride[B'[-1]].
|
|
// ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
// By (**), we know that, this quantity in the above equation
|
|
// remains the same in input and output. So both
|
|
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
|
|
// and
|
|
// stride[B[-1]]
|
|
// are the same in input and output.
|
|
//
|
|
// These two quantities are exactly the LHS and RHS of the
|
|
// invariant inequality. Since by assumption the invariant is
|
|
// satisfied in input at B[-1], it is also satisfied in output at
|
|
// B[-1]. This concludes the proof.
|
|
//
|
|
// squeeze: Special case of view
|
|
//
|
|
// unsqueeze: Special case of view
|
|
//
|
|
// slice: Consider slicing dimension i with step = k >= 1.
|
|
//
|
|
// Let stride' and size' be the output strides and sizes. We have
|
|
//
|
|
// stride'[i] = k * stride[i]
|
|
// size'[i] <= floor(size[i] / k)
|
|
//
|
|
// If size'[i] = 1, invariant is obviously satisfied as we are
|
|
// just removing a dimension (after step (1)).
|
|
//
|
|
// Assume size'[i] > 1.
|
|
//
|
|
// By assumption, the invariant is satisfied at every dimension
|
|
// in input.
|
|
//
|
|
// For any dimension j, if stride[j] > stride[i], we have
|
|
// stride'[j] = stride[j]
|
|
// > (size[i] - 1) * stride[i]
|
|
// = (size[i] / k * k - 1) * k * stride[i] / k
|
|
// = (size[i] / k - 1 / k) * stride'[i]
|
|
// >= (size'[i] - 1 / k) * stride'[i]
|
|
// >= stride'[i].
|
|
//
|
|
// If stride[j] < stride[i], we have
|
|
// stride'[j] = stride[j] < stride[i] <= stride'[i].
|
|
//
|
|
// So the sorting order remains unchanged after slice.
|
|
//
|
|
// Since
|
|
// (size'[i] - 1) * stride'[i]
|
|
// = (floor(size[i] / k) - 1) * k * stride[i]
|
|
// <= (size[i] / k - 1) * k * stride[i]
|
|
// = (size[i] - k) * stride[i]
|
|
// <= (size[i] - 1) * stride[i],
|
|
// the term from this dimension i in the invariant inequality at
|
|
// other dimensions can only decrease after slice. So the
|
|
// invariant is preserved.
|
|
//
|
|
// narrow: Special case of slice
|
|
//
|
|
// select: narrow + squeeze
|
|
//
|
|
// permute: Sorting makes permutation of dimensions irrelevant
|
|
//
|
|
// transpose: Sorting makes swapping dimensions irrelevant
|
|
//
|
|
// diagonal: Effectively merging two dimensions i and j into a new
|
|
// dimension k s.t.
|
|
// stride'[k] = stride[i] + stride[j]
|
|
// size'[k] <= min(size[i], size[j]),
|
|
// where stride and size are on the input, and stride' and size'
|
|
// are on the output.
|
|
//
|
|
// Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
|
|
// then this is unsqueeze on that dimension.
|
|
//
|
|
// WLOG, say stride[j] >= stride[i].
|
|
//
|
|
// Each dimension d in input with stride[d] > stride[j] has
|
|
// stride'[d] = stride[d]
|
|
// > (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
|
|
// >= stride[i] + stride[j]
|
|
// = stride[k].
|
|
// So, considering the sorted dimensions, this is effectively
|
|
// removing i, and replacing j with k.
|
|
//
|
|
// For dimensions d with stride[i] < stride[d] < stride[j], the
|
|
// term from dimension i is removed in the invariant inequality.
|
|
// For dimensions d with stride[d] > stride[j], we have
|
|
// (size'[k] - 1) * stride'[k]
|
|
// <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
|
|
// <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
|
|
// so the term from i and j in the invariant can only decrease.
|
|
//
|
|
// So this is generally relaxing the constraint, and thus it
|
|
// preserves it.
|
|
|
|
// This implements steps (2)~(4) of the algorithm in
|
|
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
|
|
// Helper for as_strided_backward
|
|
static inline bool _maybe_overlapping_memory(IntList sizes, IntList strides) {
|
|
if (sizes.size() > 0) {
|
|
std::vector<std::size_t> argsort(sizes.size());
|
|
std::iota(argsort.begin(), argsort.end(), 0);
|
|
std::sort(argsort.begin(), argsort.end(),
|
|
[&](std::size_t i, std::size_t j){ return strides[i] < strides[j]; });
|
|
|
|
int64_t max_index_in_slice = 0;
|
|
for (auto i : argsort) {
|
|
auto stride_ = strides[i];
|
|
if (stride_ <= max_index_in_slice) {
|
|
return true;
|
|
}
|
|
max_index_in_slice += stride_ * (sizes[i] - 1);
|
|
}
|
|
}
|
|
return false;
|
|
}
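// Illustrative examples (not generated code):
//   sizes [3, 2], strides [2, 1] (contiguous): processed in increasing stride
//     order, 1 > 0 and then 2 > max_index_in_slice == 1, so no overlap is
//     reported.
//   sizes [3, 2], strides [1, 1]: after the first dimension is processed the
//     second stride (1) is <= max_index_in_slice, so the tensor is reported as
//     possibly overlapping (indeed, elements (0, 1) and (1, 0) alias).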
|
|
|
|
// Returns the minimum storage size needed to contain a tensor of sizes, strides, and storage_offset
|
|
// Helper for as_strided_backward
|
|
static inline int64_t _min_storage_size(IntList sizes, IntList strides, int64_t storage_offset) {
|
|
int64_t storage_size = storage_offset + 1;
|
|
int64_t dim = sizes.size();
|
|
for (int64_t i = 0; i < dim; i++) {
|
|
auto size_i = sizes[i];
|
|
if (size_i == 0) {
|
|
return storage_offset;
|
|
}
|
|
storage_size += (size_i - 1) * strides[i];
|
|
}
|
|
return storage_size;
|
|
}
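// Illustrative example (not generated code): sizes [3, 2], strides [1, 3] and
// storage_offset 0 need 1 + (3 - 1) * 1 + (2 - 1) * 3 = 6 elements, i.e. the
// furthest addressable element sits at flat offset 5.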
|
|
|
|
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, int64_t storage_offset) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto odim = grad.dim();
  std::vector<int64_t> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  for (int64_t i = odim - 1; i >= 0; i--) {
    auto size_i = sizes[i];
    auto stride_i = strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      grad = grad.sum(i, false);
    } else {
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // For input geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto idim = input_geometry.dim();
  IntList inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
  std::vector<int64_t> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    auto size_i = inp_sizes[i];
    auto stride_i = inp_strides[i];
    if (size_i == 0) {
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Step (1)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);

  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create underlying tensor as "storage"
  auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
  auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
  auto out_effective_offset = storage_offset - shared_offset;
  auto base_size = std::max(
    _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
    _min_storage_size(out_sizes_, out_strides_, out_effective_offset)
  );
  auto storage = at::zeros({base_size}, grad.options());

  // prepare indices tensor if we will do index_add_ later
  c10::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
  }

  // Step (2): use output geometry to scatter gradients into storage
  if (out_maybe_overlap) {
    auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // assume that new tensors have 0 storage offset
    storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
  }

  // Step (3): if input tensor has overlapping memory, divide scattered gradient
  //           at storage[i] by the number of times i shows up in input geometry
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage);
    auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
    count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    storage.div_(count); // this will give nan outside visible range
  }
  // Step (4): return as_strided view of the storage tensor with input geometry
  return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}

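// For intuition, a small (hypothetical) example of Steps (2)~(3): with out_sizes_ = (2, 2)
// and out_strides_ = (1, 1), output elements (0, 1) and (1, 0) both read storage[1], so
// their gradients are accumulated into storage[1] via index_add_ in Step (2). Step (3) is
// the dual fix-up for an input geometry that revisits storage locations: each storage slot
// is divided by the number of input elements aliasing it before the final as_strided view.
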
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
  auto recip = (self * self + other * other).reciprocal();
  return std::tuple<Tensor,Tensor>{
            output_mask[0] ? grad * other * recip : Tensor(),
            output_mask[1] ? grad * -self * recip : Tensor() };
}

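// For reference (standard calculus, matching the expressions above):
//   d/d(self)  atan2(self, other) =  other / (self^2 + other^2)
//   d/d(other) atan2(self, other) = -self  / (self^2 + other^2)
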
// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
    const Tensor & grad_grad_input,
    const Tensor & grad_grad_weight,
    const Tensor & grad_out,
    const Tensor & input_,
    const Tensor & weight_) {

  auto input = input_.contiguous();
  auto weight = weight_.contiguous();

  // Zero-fill undefined grads (TODO: do this more efficiently)
  auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input);
  auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight);
  auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input);

  auto positive_mask = (input > 0).type_as(ggI);
  auto nonpositive_mask = (input <= 0).type_as(ggW);

  // Explanation: Let input be i, weight be w, grad_output be gO.
  // f(i, w) = i      if i > 0
  //         = w * i  if i <= 0
  // gI = df/di * gO = gO     if i > 0      gW = df/dw * gO = 0      if i > 0
  //                 = gO * w if i <= 0                     = gO * i if i <= 0
  // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.

  if (weight.numel() == 1) {
    // from PReLU.forward: num_parameters == 0 is used to indicate that a
    // single weight is shared among all input channels.

    // this is a little tricky because PReLU currently doesn't take a shape so the weight may be
    // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
    // and tensor are merged).  So, use weight and ggW as 0-dim in this case.
    bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
    auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
    auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;

    auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
    auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
    return std::tuple<Tensor, Tensor, Tensor>(
      ggO,
      ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
      (ggI * gO * nonpositive_mask).sum().expand_as(weight)
    );
  } else {
    // Expand ggW to match size of ggI; a simple expand doesn't work because
    // ggW is the size of the input channel (dim==1 unless there is only 1 dimension).  For example,
    // let ggI be size (3,4,5,6,7) and ggW be size (4).  Then we unsqueeze ggW to be size (4,1,1,1)
    // so the expand succeeds.
    auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
    auto ggW_expanded = ggW;
    for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
      ggW_expanded = ggW_expanded.unsqueeze(1);
    }
    ggW_expanded = ggW_expanded.expand_as(ggI);

    auto gI = ggW_expanded * gO * nonpositive_mask;

    auto gW = ggI * gO * nonpositive_mask;
    if (input.dim() > 1) {
      gW = gW.sum(0);
    }
    while (gW.dim() > 1) {
      gW = gW.sum(1);
    }

    Tensor ggO;
    if (gO.requires_grad()) {
      // expand weight as input as in ggW/ggI above
      auto weight_expanded = weight;
      for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
        weight_expanded = weight_expanded.unsqueeze(1);
      }
      weight_expanded = weight_expanded.expand_as(input);

      auto mask = positive_mask + nonpositive_mask * weight_expanded;
      ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
    }
    return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
  }
}

// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
          bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
  AT_CHECK(compute_uv,
           "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
           "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");

  auto m = self.size(0);
  auto n = self.size(1);
  auto k = sigma.size(0);
  auto gsigma = grads[1];

  auto u = raw_u;
  auto v = raw_v;
  auto gu = grads[0];
  auto gv = grads[2];

  if (!some) {
    // We ignore the free subspace here because possible base vectors cancel
    // each other, e.g., both -v and +v are valid basis vectors for a dimension.
    // Don't assume behavior of any particular implementation of svd.
    u = raw_u.narrow(1, 0, k);
    v = raw_v.narrow(1, 0, k);
    if (gu.defined()) {
      gu = gu.narrow(1, 0, k);
    }
    if (gv.defined()) {
      gv = gv.narrow(1, 0, k);
    }
  }
  auto vt = v.t();

  Tensor sigma_term;
  if (gsigma.defined()) {
    sigma_term = u.mm(gsigma.diag()).mm(vt);
  } else {
    sigma_term = at::zeros({1}, self.options()).expand_as(self);
  }
  // in case that there are no gu and gv, we can avoid the series of kernel
  // calls below
  if (!gv.defined() && !gu.defined()) {
    return sigma_term;
  }

  auto ut = u.t();
  auto im = eye(m, self.options());
  auto in = eye(n, self.options());
  auto sigma_mat = sigma.diag();
  auto sigma_mat_inv = sigma.pow(-1).diag();
  auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
  auto F = sigma_expanded_sq - sigma_expanded_sq.t();
  // The following two lines invert the values of F and fill the diagonal with 0s.
  // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
  // first to prevent nan from appearing in backward of this function.
  F.diagonal().fill_(INFINITY);
  F = F.pow(-1);

  Tensor u_term, v_term;

  if (gu.defined()) {
    u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat);
    if (m > k) {
      u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv);
    }
    u_term = u_term.mm(vt);
  } else {
    u_term = at::zeros({1}, self.options()).expand_as(self);
  }

  if (gv.defined()) {
    auto gvt = gv.t();
    v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt);
    if (n > k) {
      v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt)));
    }
    v_term = u.mm(v_term);
  } else {
    v_term = at::zeros({1}, self.options()).expand_as(self);
  }

  return u_term + sigma_term + v_term;
}

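// For intuition, the simplest case (only gsigma defined): since self = U diag(sigma) V^T,
// the backward reduces to the `sigma_term` above, i.e. gA = U diag(gsigma) V^T. The
// u_term/v_term then account for gradients flowing through U and V, weighted by F, the
// matrix of reciprocal pairwise differences of the squared singular values.
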
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
  AT_CHECK(eigenvectors,
           "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
           "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");

  auto glambda = grads[0];
  auto gv = grads[1];

  auto vt = v.t();

  Tensor result;
  if (gv.defined()) {
    Tensor F = lambda.unsqueeze(0).expand_as(self).clone();
    F.sub_(at::unsqueeze(lambda, 1));
    F.diagonal().fill_(INFINITY);
    F.pow_(-1);

    F.mul_(vt.mm(gv));
    result = v.mm(F.mm(vt));
  } else {
    result = at::zeros_like(self);
  }

  if (glambda.defined()) {
    result.add_((v * glambda).mm(vt));
  }
  if (upper) {
    result = at::triu(result) + at::triu(result.t(), 1);
  } else {
    result = at::tril(result) + at::tril(result.t(), -1);
  }
  return result;
}

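// For intuition (following the reference above): with self = V diag(lambda) V^T, the code
// computes
//   g_full = V (F * (V^T gV) + diag(glambda)) V^T,  F_ij = 1 / (lambda_j - lambda_i), F_ii = 0,
// where F * (...) is an elementwise product, and then folds the result into the triangle
// that symeig actually read (`upper`).
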
// Invertible case is derived from Jacobi's formula, and also can be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
  auto det_val = det.item<double>();
  if (det_val != 0 /* invertible */) {
    return grad * det * self.inverse().t();
  } else /* otherwise det = \prod(sigma) = 0, use svd */ {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    auto gsigma = prod_backward(grad, sigma, det);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
}

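// Jacobi's formula, for reference: d det(A) = det(A) * tr(A^{-1} dA), so for invertible A
// the gradient w.r.t. A is grad * det(A) * A^{-T}, which is exactly the first branch above.
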
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
  auto logdet_val = logdet.item<double>();
  if (logdet_val != -INFINITY /* det != 0, invertible */) {
    return grad * self.inverse().t();
  } else /* otherwise det = \prod(sigma) = 0, use svd */ {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    // logdet = \sum log(sigma), so the backward through sigma is grad / sigma
    auto gsigma = grad.div(sigma);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
}

Tensor slogdet_backward(const std::vector<torch::autograd::Variable> &grads,
                     const Tensor& self,
                     const Tensor& signdet, const Tensor& logabsdet) {
  AT_ASSERTM(!grads[0].defined(), "slogdet's sign output should never have gradient");
  auto signdet_val = signdet.item<double>();
  if (signdet_val != 0 /* det != 0, invertible */) {
    return grads[1] * self.inverse().t();
  } else /* otherwise det = \prod(sigma) = 0, use svd */ {
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    // sigma has all non-negative entries (also with at least one zero entry),
    // so logabsdet = \sum log(abs(sigma)) = \sum log(sigma),
    // and the backward through sigma is grads[1] / sigma
    auto gsigma = grads[1].div(sigma);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
}

// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> trtrs_backward(
    const Tensor & grad_x, const Tensor & grad_m,
    const Tensor & b, const Tensor & a, const Tensor & x,
    const bool upper, const bool transpose, const bool unitriangular,
    std::array<bool, 2> output_mask) {
  Tensor grad_b, grad_a;
  if (grad_x.defined()) {
    grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
    if (output_mask[1]) {
      grad_a = transpose ? -x.mm(grad_b.t()) : -grad_b.mm(x.t());
      if (upper) {
        grad_a = grad_a.triu((int) unitriangular);
      } else {
        grad_a = grad_a.tril(-((int) unitriangular));
      }
    }
  }
  if (!grad_a.defined()) {
    grad_a = at::zeros({1}, a.options()).expand_as(a);
  }
  if (!grad_b.defined()) {
    grad_b = at::zeros({1}, b.options()).expand_as(b);
  }
  if (output_mask[1] && grad_m.defined()) {
    grad_a = grad_a.add(grad_m);
  }
  return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}

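// For intuition (Sec. 2.3.1 of the reference): when the forward solves A x = b,
//   grad_b = A^{-T} grad_x   (the transposed triangular solve above), and
//   grad_a = -grad_b x^T     (then masked to the triangular part that was actually read).
// The transpose/upper/unitriangular handling simply mirrors how the forward used A.
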
// Generally speaking, fft's backward is ifft.
Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim,
                 bool complex_input, bool complex_output,
                 bool inverse, IntList checked_signal_sizes,
                 bool normalized, bool onesided,
                 IntList output_sizes) {
  Tensor gI;
  if (!complex_input && complex_output) {
    // Forward is R2C
    // Do inverse C2C and project onto real plane because grad can be
    // asymmetrical so C2R can't be used.
    if (onesided) {
      // Forward is R2C (onesided)
      // Think of onesided R2C rfft as
      //     1. view as complex numbers (fill complex dim with zeros)
      //     2. C2C fft
      //     3. discard half of results
      // So backward is
      //     1. fill the other half with zeros (with `zero_grad_shape` below)
      //        (C2C ifft only takes twosided inputs so we need to fill here)
      //     2. inverse C2C ifft
      //     3. discard the complex dim
      int64_t zero_length = checked_signal_sizes[signal_ndim - 1] - grad.size(signal_ndim);
      auto complex_full_grad = grad;
      if (zero_length > 0) {
        std::vector<int64_t> zero_grad_shape(signal_ndim + 2);
        zero_grad_shape[0] = self.size(0);
        for (int64_t i = 1; i < signal_ndim; i++) {
          zero_grad_shape[i] = checked_signal_sizes[i - 1];
        }
        zero_grad_shape[signal_ndim] = zero_length;
        zero_grad_shape[signal_ndim + 1] = 2;
        complex_full_grad = at::cat({ grad, at::zeros(zero_grad_shape, grad.options()) }, signal_ndim);
      }
      gI = _fft_with_size(complex_full_grad, signal_ndim,
                          /* complex_input */ true, /* complex_output */ true,
                          !inverse, checked_signal_sizes, normalized,
                          /* onesided */ false, complex_full_grad.sizes()).select(-1, 0);
    } else {
      gI = _fft_with_size(grad, signal_ndim, /* complex_input */ true,
                          /* complex_output */ true, !inverse,
                          checked_signal_sizes, normalized,
                          /* onesided */ false, grad.sizes()).select(-1, 0);
    }
  } else if (complex_input && !complex_output && onesided) {
    // Forward is C2R (onesided)
    // Think of onesided C2R irfft as
    //     1. fill the other half by conjugate symmetry
    //     2. inverse C2C ifft
    //     3. discard the complex dimension
    // So backward is
    //     1. R2C rfft (essentially add dummy complex dimension, and dft)
    //     2. accumulate gradient by conjugate symmetry
    //        since rfft results follow conjugate symmetry, we only need to
    //        double some entries from onesided rfft results, i.e., the ones with
    //        their reflected indices also landing out of the onesided range. So
    //        consider the index of last dim:
    //            i.   idx = 0.
    //                 Reflected to (N - 0) % N = 0. Not doubled.
    //            ii.  0 < idx < floor(N/2) (last).
    //                 N > N - idx > ceil(N/2)
    //                 Reflected to (N - idx), which lands outside the onesided range. Doubled.
    //            iii. idx = floor(N/2) = N/2 (last) when N even.
    //                 Reflected to (N - N/2) % N = N/2. Not doubled.
    //            iv.  idx = floor(N/2) = (N-1)/2 (last) when N odd.
    //                 Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
    //        Therefore, needs to double
    //            idx = 1, 2, ..., N/2 - 1     when N even
    //            idx = 1, 2, ..., (N-1)/2     when N odd
    //        that is
    //            idx = 1, 2, ..., N - (floor(N/2) + 1)
    //                = 1, 2, ..., N - onesided_length
    gI = _fft_with_size(grad, signal_ndim, /* complex_input */ false,
                        /* complex_output */ true, /* inverse */ false,
                        checked_signal_sizes, normalized, /* onesided */ true,
                        self.sizes());
    int64_t double_length = checked_signal_sizes[signal_ndim - 1] - self.size(signal_ndim);
    if (double_length > 0) {  // also covers case when signal size is zero
      gI.narrow(signal_ndim, 1, double_length).mul_(2);
    }
  } else {
    gI = _fft_with_size(grad, signal_ndim, complex_output, complex_input,
                        !inverse, checked_signal_sizes, normalized, onesided,
                        self.sizes());
  }
  if (normalized) {
    // If normalized, the backward is exactly the forward fft called with the
    // inverse flag flipped, because both transforms are unitary.
    return gI;
  } else {
    // If not normalized, in backward, we need to upscale or downscale gI based
    // on whether the forward is an inverse fft.
    auto signal_numel = std::accumulate(checked_signal_sizes.begin(),
                                        checked_signal_sizes.end(), 1, std::multiplies<int64_t>());
    if (!inverse) {
      return gI.mul_(static_cast<double>(signal_numel));
    } else {
      return gI.div_(static_cast<double>(signal_numel));
    }
  }
}

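// A small numeric check of the doubling rule above (illustrative only):
//   N = 4: onesided length = 3, doubled idx = 1       (1, ..., N/2 - 1)
//   N = 5: onesided length = 3, doubled idx = 1, 2    (1, ..., (N-1)/2)
// which matches `double_length = N - onesided_length` entries starting at idx 1.
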
// Helper for batchnorm_double_backward
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
  auto r = to_sum.sum(0, keepdim);
  int64_t start_point_exclusive = keepdim ? 1 : 0;
  for (int64_t dim = r.dim() - 1; dim > start_point_exclusive; dim--) {
    r = r.sum(dim, keepdim);
  }
  return r;
}

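// E.g. for an input of shape (N, C, H, W): sum_exclude_dim1(t) has shape (1, C, 1, 1)
// and sum_exclude_dim1(t, false) has shape (C), i.e. every dim except dim 1 is reduced.
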
// Helper for batchnorm_double_backward
// similar to expand_as_dim1 below, but doesn't do the expand_as; operates as if
// reductions were done with keepdim=True
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  auto src_expanded = src;
  while (src_expanded.sizes().size() < target.sizes().size() - 1) {
    src_expanded = src_expanded.unsqueeze(1);
  }
  if (src_expanded.sizes().size() == target.sizes().size() - 1) {
    src_expanded = src_expanded.unsqueeze(0);
  }
  return src_expanded;
}

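// E.g. src of shape (C) with an (N, C, H, W) target becomes (1, C, 1, 1), matching the
// shape a keepdim=True reduction over every dim except dim 1 would have produced.
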
// Helper for batchnorm_double_backward
// because gamma/ggG/ggB are 1-dimensional and represent dim==1, we can't
// do a straight expansion because it won't follow the broadcasting rules.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  auto src_expanded = src;
  while (src_expanded.sizes().size() < target.sizes().size() - 1) {
    src_expanded = src_expanded.unsqueeze(1);
  }
  return src_expanded.expand_as(target);
}

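// E.g. gamma of shape (C) with an (N, C, H, W) target is first unsqueezed to (C, 1, 1)
// and then expanded to (N, C, H, W), so the per-channel parameter broadcasts along dim 1.
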
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor & input,
    const Tensor & gamma,
    const Tensor & ggI,
    const Tensor & ggG,
    const Tensor & ggB,
    const Tensor & gO,
    const Tensor & running_mean,
    const Tensor & running_var,
    bool training,
    double eps,
    const Tensor & save_mean,
    const Tensor & save_invstd,
    std::array<bool,3> output_mask) {

  bool affine = gamma.defined();
  // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    gamma_expanded = expand_as_dim1(gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? save_mean.to(input.type().scalarType()) : running_mean, input);
  auto input_sub_mu = input - mu;
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_invstd.to(input.type().scalarType())
                                                    : running_var.add(Scalar(eps)).pow(-0.5), input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
                     (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      auto t2 = (input_mu_sigma2_neg_3_2 * sum_exclude_dim1(gO * input_sub_mu)).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
                input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped properly)
      gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  if (output_mask[0] && !ggO.defined()) ggO = at::zeros_like(gO);
  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
    gG = at::zeros_like(gamma);
  }
  if (output_mask[2] && !gI.defined()) gI = at::zeros_like(input);

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};

}

std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
                                                       IntList expand1, IntList expand2, IntList expand3,
                                                       IntList sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) {
  Tensor grad_i1, grad_i2, grad_i3;
  if (grad_mask[0])
    grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
  if (grad_mask[1])
    grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
  if (grad_mask[2])
    grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
  return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}

Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  if (self.is_sparse()) {
    AT_ERROR(
      "log1p of a sparse tensor is made to be non-differentiable since ",
      "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
      "Use a different mathematical operation which preserves sparsity of gradients, ",
      "or report a bug if you think this is an error.");
  }
  return grad / (self + 1);
}

Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntList values_shape) {
  // TODO: improve this backward by writing a kernel (maybe)
  auto dense_grad = sparse_grad_out.is_sparse() ? sparse_grad_out.to_dense() : sparse_grad_out;
  auto full_size = sparse_grad_out.sizes();
  auto flattened_grad_shape = values_shape.vec();
  flattened_grad_shape[0] = at::prod_intlist(full_size.slice(0, indices.size(0)));
  auto flattened_dense_grad = dense_grad.view(flattened_grad_shape);
  auto flattened_indices = at::sparse::flatten_indices(indices, full_size);
  return flattened_dense_grad.index_select(0, flattened_indices);
}

// Because the backward of pad(input, pads) is just pad(grad_output, [-p for p in pads])
Tensor constant_pad_nd_backward(const Tensor& grad, IntList pad) {
  auto negated_pad = pad.vec();
  std::transform(negated_pad.cbegin(), negated_pad.cend(), negated_pad.begin(), std::negate<int64_t>());
  return at::constant_pad_nd(grad, negated_pad, 0);
}

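// E.g. a forward pad of {1, 2} on the last dim becomes a backward "pad" of {-1, -2},
// i.e. constant_pad_nd simply slices the padded border off grad_output.
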
} // anonymous namespace

${autograd_function_definitions}

}}} // namespace torch::autograd::generated