Files
pytorch/tools/autograd/templates/Functions.cpp
Thomas Viehmann 14004cbef6 Native batch norm (#13263)
Summary:
- Move batch norm from TH(CU)NN to native
- Speedups in many cases (e.g. #12006) for CUDA due to new block/grid layout and Welford-type mean/variance calculations (the latter for training mode)
- It splits the forward kernel in two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.

Compared to the ill-fated #12368
- I changed the CPU kernel to not call `.sum()` from within parallel for. This seemed to have caused the breakage (NaN-results) in TestModels.test_dcgan_netG (thank you houseroad for the repro, errors in assessment of the fix are my own)
- I updated the Half->Float upcasting in tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13263

Differential Revision: D12946254

Pulled By: SsnL

fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
2018-11-06 20:05:54 -08:00

2048 lines
80 KiB
C++

// NB: Must be at the top of file to avoid including the deprecated "math.h".
// https://stackoverflow.com/questions/6563810/m-pi-works-with-math-h-but-not-with-cmath-in-visual-studio
#ifdef _MSC_VER
#define _USE_MATH_DEFINES
#include <cmath>
#endif
#include "Functions.h"
#include <ATen/Utils.h>
#include <ATen/core/TensorOptions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/Reduction.h>
#include <ciso646>
#include <algorithm>
#include <numeric>
#include <functional>
// ${generated_comment}
using at::Tensor;
using at::Scalar;
using at::IntList;
using at::TensorList;
namespace torch { namespace autograd { namespace generated {
namespace {
// Helper functions for autogenerated code
// Hands out consecutive, non-overlapping half-open index ranges
// [begin, end), one per call to range(), so that generated code can
// address the slots of a flattened list imperatively.
struct IndexRangeGenerator {
  // Reserves the next `range_size` slots and returns them as {begin, end}.
  IndexRange range(size_t range_size) {
    const size_t begin = i;
    i = begin + range_size;
    return {begin, i};
  }
  // Total number of slots reserved so far.
  size_t size() { return i; }
 private:
  size_t i = 0;
};
// Writes the single tensor `t` into the one-element slot `range` of `out`.
// The asserts guard against generated code computing an inconsistent range.
void copy_range(variable_list& out, IndexRange range, const Tensor & t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == 1, "inconsistent range for Tensor output");
  auto& slot = out[range.first];
  slot = t;
}
// Writes every tensor of `t`, in order, into the slots `range` of `out`.
// The range must be exactly as long as `t`.
void copy_range(variable_list& out, IndexRange range, at::ArrayRef<Tensor> t) {
  AT_ASSERT(range.second <= out.size());
  AT_ASSERTM(range.second - range.first == t.size(), "inconsistent range for TensorList output");
  auto dst = out.begin() + range.first;
  for (const Tensor& elem : t) {
    *dst = elem;
    ++dst;
  }
}
// Always throws std::runtime_error naming the op whose derivative is missing.
// The Tensor return type lets generated code use this in expression position.
Tensor not_implemented(const char* name) {
  std::string message = "the derivative for '";
  message += name;
  message += "' is not implemented";
  throw std::runtime_error(message);
}
// Returns `t * s`. When `s` is exactly one (as a floating-point or an integral
// Scalar) the multiplication is skipped and `t` itself is returned.
Tensor maybe_multiply(const Tensor & t, const Scalar & s) {
  const bool is_one = s.isFloatingPoint() ? (s.toDouble() == 1)
                    : s.isIntegral()      ? (s.toLong() == 1)
                    : false;
  if (!is_one) {
    return t * s;
  }
  return t;
}
// Size of the shape `sizes` along `dim` (wrapped first, so negative dims work).
// A 0-dim shape reports size 1.
int64_t _safe_size(IntList sizes, int64_t dim) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  if (sizes.size() == 0) {
    return 1;
  }
  return sizes[dim];
}
// Gradient of the p-norm `norm` of `self` (a full reduction) with respect to
// `self`, given the incoming gradient `grad`. The branches are checked in a
// fixed order: 0, 1, 2, infinite, then finite p < 2, then everything else.
// Wherever `norm` is 0 the subgradient 0 is returned.
Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_, const Tensor & norm) {
  const double p = p_.toDouble();
  if (p == 0.0) {
    return zeros_like(self);
  }
  if (p == 1.0) {
    return self.sign() * grad;
  }
  Tensor direction;  // per-element factor depending on self
  Tensor coeff;      // grad divided by the matching power of norm
  if (p == 2.0) {
    direction = self;
    coeff = grad / norm;
  } else if (std::isinf(p)) {
    // only entries whose magnitude equals the norm receive gradient
    direction = self.sign() * (self.abs() == norm).toType(self.type());
    coeff = grad.clone();
  } else if (p < 2.0) {
    direction = self.sign() * self.abs().pow(p - 1);
    coeff = grad / norm.pow(p - 1);
  } else {
    direction = self * self.abs().pow(p - 2);
    coeff = grad / norm.pow(p - 1);
  }
  // handle case at 0 where we return a subgradient containing 0
  coeff.masked_fill_(norm == 0, 0);
  return direction * coeff;
}
// p-norm gradient for a reduction along `dim`. When the reduced dimension was
// dropped (keepdim == false, non-scalar input), it is re-inserted into `grad`
// and `norm` so both broadcast against `self`.
Tensor norm_backward(Tensor grad, const Tensor & self, const Scalar & p_, Tensor norm, int64_t dim, bool keepdim) {
  const bool dim_was_dropped = !keepdim && self.dim() != 0;
  if (dim_was_dropped) {
    grad = grad.unsqueeze(dim);
    norm = norm.unsqueeze(dim);
  }
  return norm_backward(grad, self, p_, norm);
}
// d(self^e)/d(self) = e * self^(e-1) for a Scalar exponent e. For e == 0 the
// gradient is returned as exact zeros, without evaluating self.pow(-1).
Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
  const double exponent = exponent_.toDouble();
  if (exponent != 0.0) {
    return grad * exponent * self.pow(exponent - 1);
  }
  return zeros_like(self);
}
// d(self^exponent)/d(self) for a Tensor exponent. Positions with a zero
// exponent get an exact 0 instead of grad * 0 * self^-1, which would be NaN
// where self == 0.
Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & exponent) {
  auto general = grad * exponent * self.pow(exponent - 1);
  auto zero = at::zeros({}, grad.options());
  return at::where(exponent == 0.0, zero, general);
}
// d(self^exponent)/d(exponent) = self^exponent * log(self).
Tensor pow_backward_exponent(Tensor grad, const Tensor & self, const Tensor & exponent) {
  auto powered = self.pow(exponent);
  return grad * powered * self.log();
}
// d(base^exponent)/d(exponent) = base^exponent * log(base) for a Scalar base.
Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor & exponent) {
  const double log_base = std::log(base.toDouble());
  return grad * at::pow(base, exponent) * log_base;
}
// Gradient of the multivariate log-gamma function
//   mvlgamma(x, p) = p(p-1)/4 * log(pi) + sum_{i=1..p} lgamma(x + (1 - i) / 2),
// which is grad * sum_{i=1..p} digamma(x + (1 - i) / 2). The log(pi) term is
// constant in x and therefore contributes nothing to the derivative.
Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) {
  // Offsets (1 - i) / 2 for i = 1..p, i.e. [(1 - p) / 2, ..., -1/2, 0].
  // The integer range [-p + 1, 1) is increasing, so the step must be +1.
  // A negative step here is rejected by arange.
  Tensor args = at::arange(-p + 1, 1, 1, self.options()).div_(2.);
  // Broadcast to shape self.sizes() + [p]: one column per offset.
  args = args.add(self.unsqueeze(-1));
  return grad * args.digamma_().sum(-1);
}
// The gradient of permute(fwd_dims) is `grad` permuted by the inverse
// permutation: if output dim i came from input dim fwd_dims[i], then input
// dim fwd_dims[i] must come back from output dim i.
Tensor permute_backwards(const Tensor & grad, IntList fwd_dims) {
  const auto ndims = fwd_dims.size();
  std::vector<int64_t> inverse(ndims);
  for (size_t out_pos = 0; out_pos < ndims; ++out_pos) {
    const int64_t src_dim = at::maybe_wrap_dim(fwd_dims[out_pos], ndims);
    inverse[src_dim] = out_pos;
  }
  return grad.permute(inverse);
}
// Broadcasts the gradient of a sum back to the input shape `sizes`. If the
// reduced dims `dims` were dropped (keepdim == false, non-scalar input), they
// are first re-inserted as size-1 dims in increasing order.
Tensor sum_backward(const Tensor & grad, IntList sizes, IntList dims, bool keepdim) {
  if (keepdim || sizes.size() == 0) {
    return grad.expand(sizes);
  }
  if (dims.size() == 1) {
    return grad.unsqueeze(dims[0]).expand(sizes);
  }
  const auto reduced = dim_list_to_bitset(dims, sizes.size());
  Tensor restored = grad;
  for (size_t d = 0; d < sizes.size(); ++d) {
    if (reduced[d]) {
      restored = restored.unsqueeze(d);
    }
  }
  return restored.expand(sizes);
}
// Returns a copy of `t` with the order of its elements along `dim` reversed.
Tensor reverse_dim(const Tensor& t, int64_t dim) {
  const int64_t last = t.size(dim) - 1;
  Tensor reversed_idx = at::arange(last, -1, -1, t.options().dtype(at::kLong));
  return t.index_select(dim, reversed_idx);
}
// Gradient of prod along `dim` that stays exact when `inp` contains zeros.
// For element k, the gradient factor is prod_{j<k} x_j * prod_{j>k} x_j.
// That is an exclusive forward cumprod times an exclusive reverse cumprod.
Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t dim) {
  const int64_t n = inp.size(dim);
  if (n == 1) {
    return grad;
  }
  auto ones_size = inp.sizes().vec();
  ones_size[dim] = 1;
  Tensor ones = at::ones(ones_size, grad.options());

  // cumprod of [1, x_0, ..., x_{n-2}] gives prod_{j<k} x_j at position k
  Tensor before = at::cat({ones, inp.narrow(dim, 0, n - 1)}, dim).cumprod(dim);

  // cumprod of [1, x_{n-1}, ..., x_1], reversed, gives prod_{j>k} x_j
  Tensor tail_reversed = reverse_dim(inp.narrow(dim, 1, n - 1), dim);
  Tensor after = reverse_dim(at::cat({ones, tail_reversed}, dim).cumprod(dim), dim);

  return grad * (before * after);
}
// Note that the gradient for prod is equivalent to:
// cumprod(exclusive, normal) * cumprod(exclusive, reverse), e.g.:
// input:                        [    a,     b,     c]
// cumprod(exclusive, normal):   [1    ,     a, a * b]
// cumprod(exclusive, reverse):  [b * c,     c,     1]
// product:                      [b * c, a * c, a * b]
// and this is safe under input with 0s.
//
// Gradient of a full prod reduction. No zeros: grad * result / input. Two or
// more zeros: every partial product contains a zero, so the gradient is all 0.
// Exactly one zero: use the zero-safe exclusive-cumprod formulation above.
Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& result) {
  if (input.dim() == 0) {
    return grad;
  }
  Tensor zero_idx = (input == 0).nonzero();
  if (zero_idx.numel() == 0) {
    return (grad * result) / input;
  }
  if (zero_idx.size(0) > 1) {
    return zeros_like(input);
  }
  Tensor flat_input = input.contiguous().view(-1);
  return prod_safe_zeros_backward(grad, flat_input, 0).view_as(input);
}
// Gradient of prod along `dim`. The cheap formula grad * result / input is
// used only when the input has no zeros at all; otherwise the zero-safe
// exclusive-cumprod version is used.
Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t dim, bool keepdim) {
  if (input.dim() == 0) {
    return grad;
  }
  dim = at::maybe_wrap_dim(dim, input.sizes().size());
  // Re-insert the reduced dim so grad/result broadcast against input. 1-d
  // inputs skip this: their 0-dim grad/result broadcast as they are.
  const bool restore_dim = !keepdim && input.dim() != 1;
  if (restore_dim) {
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }
  Tensor slice_zero_count = (input == 0).sum(dim, true);
  const int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
  if (total_zeros != 0) {
    return prod_safe_zeros_backward(grad, input, dim);
  }
  return (grad * result) / input;
}
// Despite the name, this returns the reverse *inclusive* cumulative sum of `x`
// along `dim`: out[k] = sum_{j >= k} x[j]. It is computed as
// total - cumsum(x)[k] + x[k], using the last entry of cumsum(-x) as -total.
Tensor sum_scan_exclusive(const Tensor& x, int64_t dim) {
  Tensor out = at::cumsum(-x, dim);
  const int64_t last = out.size(dim) - 1;
  Tensor neg_total = out.narrow(dim, last, 1).clone();
  out -= neg_total.expand_as(out);
  out += x;
  return out;
}
// Gradient of y = cumprod(input, dim) with respect to input, given grad = dF/dy.
// Uses an O(n) formula when every input element is nonzero, and an O(n^2)
// zero-safe formula otherwise (n = input.size(dim)).
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim) {
/*
There are two algorithms to do this. The first one
is very efficient, but works only when there are no
zero elements in the input.
The second one is much more complex, but it doesn't
assume anything on the input. The main downside is
that it takes time O(n^2), where n = input.size(self.dim)
(i.e. the length of the cumulative product). This is in
contrast to the forward pass and the efficient algorithm,
which are both O(n).
The second algorithm is a simple application of the chain
rule. If x is an n-dimensional vector, and y = cumprod(x),
and F is the final cost, then
dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k) (1)
The term dF / dy_j is just grad_output[j] (assuming again
everything is one-dimensional).
The term (dy_j / dx_k) is easily seen to be
if j >= k
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
else:
dy_j / dx_k = 0
Note that the indicator (j>=k) can be taken out
by replacing the sum in (1) with a sum from
j = k to n.
Thus,
dF / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)
with
dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i (2)
Note that this last term is just the cumulative product
with k omitted. Thus, if x_k (the input) is nonzero, we can
just express this as
dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
= y_j / x_k
So therefore,
dF / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k
so
grad_input = sum_scan_exclusive(grad_output * output) / input
(sum_scan_exclusive computes the reverse inclusive sum over j >= k).
If the input contains zeros, we need to calculate the dy_j / dx_k
by using the formula (2), called in the code omitted_products.
The way the code calculates it is simply by noting that
prod_{1 <= i <= j, i != k} x_i
= (prod_{1 <= i < k} x_i) * (prod_{k + 1 <= i <= j} x_i)
the first term is calculated as prods_until_k, which since it
doesn't depend on j is easy to vectorize.
The second term (indexed by j) is the cumulative product of
x_{k+1}, x_{k+2}, ..., x_n, and it's named in the code
prods_from_k_plus_1, and it's calculated as a cumprod.
The terms with j < k (where dy_j / dx_k = 0) are not padded
into omitted_products; instead, right after the assert, grad
is sliced to the positions j >= k before being multiplied
with omitted_products and summed over dim.
*/
if (input.dim() == 0) {
return grad;
}
dim = at::maybe_wrap_dim(dim, input.sizes().size());
int64_t dim_size = input.size(dim);
if (dim_size == 1) {
return grad;
}
// Simple case: every element of the input is nonzero
if ((input != 0).all().item<uint8_t>()) {
Tensor result = at::cumprod(input, dim);
return sum_scan_exclusive(result * grad, dim) / input;
}
// Size-1 slab along dim, used as the leading "1" (empty product) for k == 0
auto ones_size = input.sizes().vec();
ones_size[dim] = 1;
Tensor ones = at::ones({1}, grad.options()).expand(ones_size);
Tensor grad_input = at::zeros(input.sizes(), grad.options());
Tensor prods_from_k_plus_1;
Tensor omitted_products;
// One pass per position k along dim: O(n^2) work overall
for (int k = 0; k < dim_size; ++k) {
if (k == 0) {
// nothing precedes x_0: products are [1, x_1, x_1 x_2, ...]
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
} else if (k == dim_size - 1) {
// nothing follows the last element: only the j == k term remains
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
omitted_products = prods_until_k;
} else {
// j == k term is prods_until_k; j > k terms also multiply in x_{k+1..j}
Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
omitted_products = at::cat({prods_until_k, omitted_products}, dim);
}
// At this point omitted_products is the same size
// as input, except on the dimension dim where it's
// dim_size - k
AT_ASSERT(omitted_products.size(dim) == dim_size - k);
// Sum over j >= k of grad[j] * dy_j/dx_k
grad_input.select(dim, k).copy_(
at::sum(grad.slice(dim, k) * omitted_products,dim));
}
return grad_input;
}
// Overload for cumprod called with an explicit output dtype. The incoming
// gradient is cast back to the input's scalar type before differentiating.
// `dtype` itself is not otherwise used.
Tensor cumprod_backward(const Tensor &grad, const Tensor &input, int64_t dim, ScalarType dtype) {
  const auto input_type = input.scalar_type();
  Tensor grad_in_input_type = grad.to(input_type);
  return cumprod_backward(grad_in_input_type, input, dim);
}
// Gradient of the gesv solution with respect to the right-hand side `self`.
// It solves the system with A transposed on its last two dims, applied to
// `grad`. Assumes at::gesv(B, A) solves A X = B and returns (X, LU).
Tensor gesv_backward_self(const Tensor & grad, const Tensor & self, const Tensor & A) {
  Tensor A_t = A.transpose(-2, -1);
  auto solved = at::gesv(grad, A_t);
  return std::get<0>(solved);
}
// Gradient of the gesv solution with respect to A:
//   -grad_self @ solution^T (transposing the last two dims).
// Plain 2-d operands use mm; anything else (batched) goes through matmul.
Tensor gesv_backward_A(const Tensor & grad, const Tensor & self, const Tensor & A, const Tensor & solution) {
  Tensor grad_self = gesv_backward_self(grad, self, A);
  Tensor solution_t = solution.transpose(-2, -1);
  const bool both_matrices = self.ndimension() == 2 && A.ndimension() == 2;
  Tensor product = both_matrices ? at::mm(grad_self, solution_t)
                                 : at::matmul(grad_self, solution_t);
  return -product;
}
// Gradient of cumsum along `dim`. It is the reverse inclusive cumulative sum
// of the incoming gradient `x`: out[k] = sum_{j >= k} x[j]. A 0-dim `x` is
// returned unchanged.
Tensor cumsum_backward(const Tensor & x, int64_t dim) {
  if (x.dim() == 0) {
    return x;
  }
  Tensor out = at::cumsum(-x, dim);
  const int64_t last = out.size(dim) - 1;
  Tensor neg_total = out.narrow(dim, last, 1).clone();
  out -= neg_total.expand(out.sizes());
  out += x;
  return out;
}
// Overload for cumsum called with an explicit dtype. The incoming gradient is
// cast to the original input dtype first.
Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) {
  Tensor x_in_input_dtype = x.to(input_dtype);
  return cumsum_backward(x_in_input_dtype, dim);
}
// d logsumexp(self)/d self = exp(self - logsumexp(self)) along `dim`. The
// reduced dim is re-inserted first when it was dropped, so grad/result
// broadcast against self.
Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, int64_t dim, bool keepdim) {
  const bool restore_dim = !keepdim && self.dim() != 0;
  if (restore_dim) {
    grad = grad.unsqueeze(dim);
    result = result.unsqueeze(dim);
  }
  Tensor weights = (self - result).exp();
  return grad * weights;
}
// Gradient of unbind: stacks the per-slice gradients back along `dim`.
// Undefined gradients are materialized as zeros with the sizes and options
// of the first defined gradient.
Tensor unbind_backward(const variable_list& grads, int64_t dim) {
  IntList sizes;
  at::TensorOptions o;
  auto first_defined = std::find_if(grads.begin(), grads.end(),
      [](const Variable& v) { return v.defined(); });
  if (first_defined != grads.end()) {
    sizes = first_defined->sizes();
    o = static_cast<Tensor>(*first_defined).options();
  }
  auto grads_tensors = fmap(grads, [&](const Variable &v) {
    return v.defined() ? static_cast<Tensor>(v) : at::zeros({}, o).expand(sizes);
  });
  return at::stack(grads_tensors, dim);
}
// Re-inserts, from left to right, a size-1 dim into `self` at every position
// where `sizes` has extent 1.
Tensor unsqueeze_to(const Tensor & self, IntList sizes) {
  Tensor expanded = self;
  const int64_t ndims = sizes.size();
  for (int64_t d = 0; d < ndims; ++d) {
    if (sizes[d] == 1) {
      expanded = expanded.unsqueeze(d);
    }
  }
  return expanded;
}
// Re-inserts dim `dim` into `self` if the original shape `sizes` had extent 1
// there. In NumPy it's not an error to unsqueeze a scalar, but we still need
// to avoid unsqueezing in the backward, so 0-dim shapes are left alone.
Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntList sizes) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const bool was_size_one = sizes.size() > 0 && sizes[dim] == 1;
  return was_size_one ? self.unsqueeze(dim) : self;
}
std::vector<Tensor> cat_tensors_backward(const Tensor & grad, const std::vector<std::vector<int64_t>> &sizes, int64_t dim) {
dim = at::legacy_cat_wrap_dim(dim, sizes);
std::vector<Tensor> grad_inputs(sizes.size());
int64_t accumulate = 0;
for (size_t i = 0; i < sizes.size(); ++i) {
auto& shape = sizes[i];
// If input was empty tensor, gradInput should be empty tensor.
if (shape == std::vector<int64_t>({0})) {
grad_inputs[i] = at::zeros({0}, grad.options());
continue;
}
auto size = shape[dim];
accumulate += size;
grad_inputs[i] = grad.narrow(dim, accumulate - size, size);
}
return grad_inputs;
}
// Gradient of clamp: grad passes through where self lies within the given
// bounds and is zero outside them. Gradients are not defined at min and max
// themselves, so the subgradient 1 is used there.
Tensor clamp_backward(const Tensor & grad, const Tensor &self, const optional<Scalar> & min, const optional<Scalar> & max) {
  if (!min && !max) {
    return grad;
  }
  if (!max) {
    return grad * (self >= *min).type_as(grad);
  }
  if (!min) {
    return grad * (self <= *max).type_as(grad);
  }
  return grad * ((self >= *min) * (self <= *max)).type_as(grad);
}
// Gradient of alpha * (mat1 @ mat2) with respect to mat1: alpha * grad @ mat2^T.
// `sizes`/`strides` describe mat1. If mat1 was column-major, the result is
// built as the transpose of mat2 @ grad^T so it comes out column-major too,
// for efficiency.
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, IntList sizes, IntList strides, const Scalar & alpha) {
  const bool column_major = strides[0] == 1 && strides[1] == sizes[0];
  Tensor g = column_major ? mat2.mm(grad.t()).t() : grad.mm(mat2.t());
  return maybe_multiply(g, alpha);
}
// Gradient of alpha * (mat1 @ mat2) with respect to mat1, taking mat1 itself.
// Sparse mat1 is rejected. Otherwise this defers to the sizes/strides
// overload, which returns a column-major result when mat1 was column-major.
// Delegating keeps that layout logic in a single place instead of
// duplicating it here.
Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor & mat1, const Scalar & alpha) {
  if (mat1.is_sparse()) {
    throw std::runtime_error("calculating the gradient of a sparse Tensor argument to mm is not supported.");
  }
  return mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha);
}
// Gradient of alpha * (mat1 @ mat2) with respect to mat2: alpha * mat1^T @ grad.
// `sizes`/`strides` describe mat2. If mat2 was column-major, the result is
// built as (grad^T @ mat1)^T so it comes out column-major too, for efficiency.
Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntList sizes, IntList strides, const Scalar & alpha) {
  const bool column_major = strides[0] == 1 && strides[1] == sizes[0];
  Tensor g = column_major ? grad.t().mm(mat1).t() : mat1.t().mm(grad);
  return maybe_multiply(g, alpha);
}
// Gradient of renorm(self, p, dim, maxnorm). For each slice along `dim`:
// - where its p-norm is below maxnorm, grad passes through unchanged;
// - elsewhere it is the derivative of maxnorm * x / ||x||_p, with 1e-7 added
//   to the norm before inverting.
Tensor renorm_backward(const Tensor & grad, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm) {
  // renorm computes the norm over all dimensions except `dim`, so work on a
  // 2-d view with `dim` moved to the front and everything else flattened.
  // TODO: simplify this when we add support for norm over multiple dimensions.
  const auto transposed_sizes = self.transpose(dim, 0).sizes().vec();
  auto to_2d = [&](const Tensor & t) {
    return t.transpose(dim, 0).contiguous().view({t.size(dim), -1});
  };
  auto from_2d = [&](const Tensor & t) {
    return t.contiguous().view(transposed_sizes).transpose(dim, 0);
  };

  Tensor self_2d = to_2d(self);
  Tensor grad_2d = to_2d(grad);
  Tensor row_norm = self_2d.norm(p, 1, true);
  Tensor row_dot = (self_2d * grad_2d).sum(1, true);
  Tensor dnorm = norm_backward(row_dot, self_2d, p, row_norm, 1, true);
  Tensor inv_norm = (row_norm + 1e-7).reciprocal();
  Tensor scaled_grad = from_2d(maxnorm * inv_norm * (grad_2d - inv_norm * dnorm));
  Tensor full_norm = from_2d(row_norm.expand_as(self_2d));
  // TODO: remove the detach once comparison ops no longer require grad
  auto below_max = Variable(full_norm < maxnorm).detach();
  return at::where(below_max, grad, scaled_grad);
}
// Elementwise sum of all tensors in `tl`, accumulated left to right starting
// from tl[0]. Throws std::runtime_error on an empty list.
Tensor sum_tensorlist(TensorList tl) {
  if (tl.size() == 0) {
    throw std::runtime_error("Can't sum tensorlist of size 0");
  }
  return std::accumulate(tl.begin() + 1, tl.end(), tl[0],
      [](const Tensor& acc, const Tensor& t) { return acc + t; });
}
// Gradient of repeat(repeats).
// 1. Sum away the leading dims that repeat prepended beyond the input's rank.
// 2. For every remaining dim with repeat count r != 1, add up its r equal
//    chunks.
Tensor repeat_backward(Tensor grad, int64_t input_dims, IntList repeats) {
  const int64_t num_unsqueezed = grad.dim() - input_dims;
  for (int64_t remaining = num_unsqueezed; remaining > 0; --remaining) {
    grad = grad.sum(0, false);
  }
  for (size_t j = num_unsqueezed; j < repeats.size(); ++j) {
    const int64_t repeat = repeats[j];
    if (repeat != 1) {
      const int64_t dim = j - num_unsqueezed;
      grad = sum_tensorlist(grad.chunk(repeat, dim));
    }
  }
  return grad;
}
// Gradient of fused dropout. `p1m` == 1 - p, and elements kept by `mask` are
// scaled by 1 / p1m. When double backward is required (grad.requires_grad()),
// an autograd-friendly composite of differentiable ops is used instead of the
// fused kernel.
Tensor _fused_dropout_backward(Tensor grad, Tensor mask, double p1m) {
  const double scale = 1. / p1m;
  if (!grad.requires_grad()) {
    return at::_masked_scale(grad, mask, scale);
  }
  return grad * (mask.type_as(grad) * scale);
}
// Places `grad` at every position where `input` equals `value`; zeros elsewhere.
Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor & value) {
  Tensor grad_input = zeros_like(input);
  Tensor hits = input == value;
  grad_input.masked_fill_(hits, grad);
  return grad_input;
}
// Scatters `grad` into a zero tensor of shape `sizes` at `indices` along `dim`.
// The reduced dim is re-inserted into grad/indices first when it was dropped
// (keepdim == false, non-scalar shape).
Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntList sizes, bool keepdim) {
  const bool restore_dim = !keepdim && sizes.size() > 0;
  if (restore_dim) {
    grad = grad.unsqueeze(dim);
    indices = indices.unsqueeze(dim);
  }
  Tensor grad_input = at::zeros(sizes, grad.options());
  return grad_input.scatter_(dim, indices, grad);
}
// Embeds `grad` into a zero tensor of shape `input_sizes` at the region
// slice(dim, start, end, step).
Tensor slice_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  Tensor grad_input = at::zeros(input_sizes, grad.options());
  Tensor region = grad_input.slice(dim, start, end, step);
  region.copy_(grad);
  return grad_input;
}
// Embeds `grad` into a zero tensor of shape `input_sizes` at select(dim, index).
Tensor select_backward(Tensor grad, IntList input_sizes, int64_t dim, int64_t index) {
  Tensor grad_input = at::zeros(input_sizes, grad.options());
  Tensor selected = grad_input.select(dim, index);
  selected.copy_(grad);
  return grad_input;
}
// Gradient of trace: `grad` on the main diagonal of a sizes[0] x sizes[1]
// matrix, zeros elsewhere. In the flattened row-major buffer, the diagonal
// entries sit every sizes[1] + 1 elements, starting at 0.
Tensor trace_backward(const Tensor & grad, IntList sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }
  const int64_t rows = sizes[0];
  const int64_t cols = sizes[1];
  auto grad_input = at::zeros(rows * cols, grad.options());
  auto diag_positions = at::arange(0, grad_input.numel(), cols + 1, grad.options().dtype(at::kLong));
  grad_input.index_fill_(0, diag_positions, grad);
  return grad_input.view(sizes);
}
// Gradient of unfold(dim, size, step). Every window element adds its gradient
// back onto the input position it was read from. Windows overlap when
// step < size, hence index_add_ rather than a plain copy.
Tensor unfold_backward(const Tensor & grad, IntList input_sizes, int64_t dim, int64_t size, int64_t step) {
  // Total element count. This is computed without a loop variable, so nothing
  // shadows the window-size parameter `size` used by unfold() below.
  const int64_t numel = std::accumulate(
      input_sizes.begin(), input_sizes.end(), int64_t{1}, std::multiplies<int64_t>());
  // idx holds each input element's flat position. Unfolding it exactly like the
  // forward did records where every window element came from.
  auto idx = at::arange(0, numel, grad.options().dtype(at::kLong)).view(input_sizes);
  auto idx_unfolded = idx.unfold(dim, size, step).contiguous().view(-1);
  auto grad_input = at::zeros({numel}, grad.options());
  grad_input.index_add_(0, idx_unfolded, grad.contiguous().view(-1));
  return grad_input.view(input_sizes);
}
// Gradient of the variance over all elements:
//   2 / (N - unbiased) * grad * (self - mean(self)), with N = self.numel().
Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
  const double coeff = 2.0 / (self.numel() - unbiased);
  return coeff * grad * (self - self.mean());
}
// Gradient of the variance along `dim`:
//   2 / (N - unbiased) * grad * (self - mean(self, dim)), with N = self.size(dim).
// 0-dim inputs fall back to the full-reduction formula. The reduced dim is
// re-inserted into grad only for inputs with more than one dim.
Tensor var_backward(Tensor grad, const Tensor & self, int64_t dim, bool unbiased, bool keepdim) {
  if (self.dim() == 0) {
    return var_backward(grad, self, unbiased);
  }
  const bool restore_dim = !keepdim && self.dim() > 1;
  if (restore_dim) {
    grad = grad.unsqueeze(dim);
  }
  const double coeff = 2.0 / (self.size(dim) - unbiased);
  return coeff * grad * (self - self.mean(dim, true));
}
// Gradient of masked_scatter with respect to the source tensor. It takes the
// grad values at masked positions (in order), zero-pads them to the source's
// element count, and reshapes to `sizes`.
Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList sizes) {
  const int64_t numel = std::accumulate(
      sizes.begin(), sizes.end(), int64_t{1}, std::multiplies<int64_t>());
  auto mask_selected = grad.masked_select(mask);
  const int64_t diff_nelem = numel - mask_selected.numel();
  if (diff_nelem > 0) {
    // masked_select returns a 1-d tensor holding only the selected elements,
    // so fill out the rest with zeros before reshaping to the source's size.
    auto zeros_fillin = at::zeros({diff_nelem}, grad.options());
    mask_selected = at::cat({mask_selected, zeros_fillin}, 0);
  }
  return mask_selected.view(sizes);
}
// Gradient of the Cholesky factorization, cf. Iain Murray (2016); arXiv 1602.07527.
// The computation works on the lower factor. An upper factor is handled by
// transposing L and grad on the way in and the result on the way out.
Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
  if (upper) {
    L = L.t();
    grad = grad.t();
  }
  // phi(A): lower triangle of A with its diagonal halved.
  auto phi = [](const Tensor & A) -> Tensor {
    Tensor lower = A.tril();
    return lower - 0.5 * at::diag(at::diag(lower));
  };
  // make sure not to double-count variation, since only half of the output
  // matrix is unique
  Tensor Lbar = grad.tril();
  Tensor P = phi(at::mm(L.t(), Lbar));
  Tensor S = std::get<0>(at::gesv(P + P.t(), L.t()));
  S = std::get<0>(at::gesv(S.t(), L.t()));
  S = phi(S);
  return upper ? S.t() : S;
}
// Gradient of split_with_sizes: concatenates the per-piece gradients back
// along `dim`. Some grads may be undefined (standing for all-zero tensors).
// Since at::cat can't handle those, they are materialized as zeros of the
// matching piece shape first.
Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
                                 IntList split_sizes, int64_t dim, IntList sizes, const Type &type) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  std::vector<Tensor> grads_all_defined;
  grads_all_defined.reserve(grads.size());
  for (size_t j = 0; j < grads.size(); ++j) {
    if (grads[j].defined()) {
      grads_all_defined.push_back(grads[j]);
      continue;
    }
    auto piece_size = sizes.vec();
    piece_size[dim] = split_sizes[j];
    grads_all_defined.push_back(at::zeros(piece_size, type));
  }
  return at::cat(grads_all_defined, dim);
}
// Gradient of split(split_size): every piece has split_size elements along
// `dim` except the last, which holds the remainder
// dim_size - split_size * (num_splits - 1).
Tensor split_backward(const std::vector<torch::autograd::Variable> &grads,
                      int64_t split_size, int64_t dim, IntList sizes, const Type &type) {
  dim = at::maybe_wrap_dim(dim, sizes.size());
  const int64_t dim_size = sizes[dim];
  const int64_t num_splits = grads.size();
  std::vector<int64_t> piece_sizes(num_splits, split_size);
  piece_sizes[num_splits - 1] = split_size - (split_size * num_splits - dim_size);
  return split_with_sizes_backward(grads, piece_sizes, dim, sizes, type);
}
// Double backward of max pooling. Gathers from `grad` the entries addressed
// by `indices` (presumably the argmax positions saved by the forward; confirm
// with caller). The last `dim` dims of both are flattened into one before the
// gather, and the result has indices' shape.
Tensor max_pool_double_backward(const Tensor & grad, const Tensor & indices, int dim) {
  AT_ASSERT(indices.dim() >= dim);
  auto size = indices.sizes().slice(0, indices.dim() - dim).vec();
  size.push_back(-1);
  Tensor flat_indices = indices.view(size);
  Tensor flat_grad = grad.contiguous().view(size);
  return flat_grad.gather(-1, flat_indices).view(indices.sizes());
}
// Double backward of GLU with respect to `input`. Here GLU(input) = a * sigmoid(b),
// where input = [a, b] is split in halves along `dim`. `grad` is the gradient
// flowing into the first-order grad_input, and `grad_output` (gO) is the
// first-order output gradient.
Tensor glu_double_backward(const Tensor & grad, const Tensor & grad_output, const Tensor & input, int64_t dim) {
  const auto& gO = grad_output;
  const auto half = input.size(dim) / 2;
  auto a = input.narrow(dim, 0, half);
  auto b = input.narrow(dim, half, half);
  auto sig_b = b.sigmoid();
  auto one_minus_sig_b = 1 - sig_b;
  auto dsig_b = sig_b * one_minus_sig_b;  // sigmoid'(b)

  auto ggI_a = grad.narrow(dim, 0, half);
  auto ggI_b = grad.narrow(dim, half, half);
  auto ggI_b_times_a = ggI_b * a;

  auto gI_a = ggI_b * gO * dsig_b;
  auto d2sig_b = dsig_b * one_minus_sig_b - sig_b * dsig_b;  // sigmoid''(b)
  auto gI_b = ggI_b_times_a * gO * d2sig_b + ggI_a * gO * dsig_b;
  return at::cat({gI_a, gI_b}, dim);
}
// Double backward of GLU with respect to grad_output. The GLU backward is
// applied to an all-ones output gradient and scaled by `grad`. The two halves
// along `dim` are then added, because both feed the same output element.
Tensor glu_double_backward_grad_output(const Tensor & grad, const Tensor & input, int64_t dim) {
  if (dim < 0) {
    dim += input.dim();
  }
  auto out_sizes = input.sizes().vec();
  out_sizes[dim] /= 2;
  const int64_t half = out_sizes[dim];
  auto tmp = grad * glu_backward(at::ones(out_sizes, input.options()), input, dim);
  return tmp.narrow(dim, 0, half) + tmp.narrow(dim, half, half);
}
// Double backward of KL divergence with respect to grad_output. It computes the
// unreduced backward with `grad`, then applies the requested reduction.
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto unreduced = kl_div_backward(grad, input, target, Reduction::None);
  switch (reduction) {
    case Reduction::Mean:
      return unreduced.mean();
    case Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}
// Compute derivatives for targets.
// Assume targets are given as probabilities (i.e. without taking the logarithm).
// d/d target = grad_output * (log(target) + 1 - self), divided by
// target.numel() under mean reduction, and zeroed wherever target == 0.
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) {
  Tensor grad_target = grad_output.mul(target.log().add_(1).sub_(self));
  if (reduction == Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target.masked_fill_(target == 0, 0.);
}
// Gradient of BCE-with-logits with respect to `target`:
//   grad_output * weight * (log(1 - sigmoid(self)) - pos_weight * log(sigmoid(self))).
// Without pos_weight this reduces to -grad_output * weight * self. Under mean
// reduction the result is divided by target.numel().
Tensor binary_cross_entropy_with_logits_target_backward(const Tensor& grad_output, const Tensor& self, const Tensor& target, const Tensor& weight, const Tensor& pos_weight, int64_t reduction) {
  Tensor grad_target;
  if (!pos_weight.defined()) {
    grad_target = self.mul(-grad_output);
  } else {
    Tensor log_one_minus_sig = (1. - self.sigmoid()).log_();
    Tensor log_sig = self.sigmoid().log_();
    grad_target = log_one_minus_sig.sub_(pos_weight.mul(log_sig)).mul_(grad_output);
  }
  if (weight.defined()) {
    grad_target.mul_(weight);
  }
  if (reduction == Reduction::Mean) {
    grad_target.div_(target.numel());
  }
  return grad_target;
}
// The second derivative of log(sigmoid(x)) is sigmoid(x) * (sigmoid(x) - 1).
Tensor log_sigmoid_double_backward(const Tensor & grad, const Tensor & input) {
  Tensor sig = input.sigmoid();
  return grad * (sig - 1) * sig;
}
// Double backward of softmax with respect to its input. Notation:
//   y = output (the softmax result), g = grad_output,
//   h = grad (gradient flowing into the first-order grad_input).
// All sums are taken along `dim` with keepdim.
Tensor softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  const Tensor& y = output;
  const Tensor& g = grad_output;
  const Tensor& h = grad;

  auto hy = h * y;
  auto sum_hy = hy.sum(dim, true);
  auto sum_hy_y = sum_hy * y;
  auto sum_gy = (g * y).sum(dim, true);

  // gI calculation
  auto term0 = hy * (g - sum_gy);
  auto term1 = y * ((hy * g).sum(dim, true).sub_(sum_gy * sum_hy));
  auto term2 = sum_hy_y * g;
  auto term3 = sum_hy_y * sum_gy;
  return term0 - term1 - term2 + term3;
}
// Double backward of log_softmax with respect to its input. `output` is the
// log-softmax result, so exp(output) recovers the softmax probabilities.
Tensor log_softmax_double_backward(const Tensor & grad, const Tensor & grad_output, int dim, const Tensor & output) {
  Tensor probs = output.exp();
  Tensor sum_gO = grad_output.sum(dim, true);
  Tensor sum_grad_probs = (grad * probs).sum(dim, true);
  return probs * sum_gO * (sum_grad_probs - grad);
}
// Double backward of L1 loss with respect to grad_output. It computes the
// unreduced backward with `grad`, then applies the requested reduction.
Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto unreduced = l1_loss_backward(grad, input, target, Reduction::None);
  if (reduction == Reduction::Sum) {
    return unreduced.sum();
  }
  if (reduction == Reduction::Mean) {
    return unreduced.mean();
  }
  return unreduced;
}
// Second derivative of smooth L1 loss with respect to input. It is 1 where
// |input - target| < 1 (the quadratic region) and 0 elsewhere, scaled by
// `grad` and divided by numel under mean reduction.
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto abs_diff = (input - target).abs();
  auto in_quadratic_region = (abs_diff < 1).toType(grad.type());
  auto grad_input = grad * in_quadratic_region;
  if (reduction == Reduction::Mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}
// Double backward of smooth L1 loss with respect to grad_output.
// - Unreduced: the backward is linear in grad_output, so it is applied to
//   `grad` directly.
// - Reduced: the backward is taken with a ones-like grad_output, multiplied
//   elementwise by `grad`, and summed.
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  if (reduction == Reduction::None) {
    return smooth_l1_loss_backward(grad, input, target, reduction);
  }
  auto unit_grad = ones_like(grad_output);
  auto per_elem = smooth_l1_loss_backward(unit_grad, input, target, reduction);
  return (per_elem * grad).sum();
}
// Gradient of diag(input, diagonal).
// - Vector or square-matrix input: grad.diag(diagonal) yields the right shape.
// - Non-square matrix: grad is written onto the selected diagonal of a zero
//   matrix of the input's shape.
Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
  auto ndimension = input_sizes.size();
  AT_ASSERT(ndimension == 1 || ndimension == 2);
  const bool vector_or_square = ndimension == 1 || input_sizes[0] == input_sizes[1];
  if (vector_or_square) {
    return grad.diag(diagonal);
  }
  // Input was a matrix but was not square
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.diagonal(diagonal).copy_(grad);
  return grad_input;
}
// Gradient of diagonal(offset, dim1, dim2). It writes `grad` onto that diagonal
// of a zero tensor shaped like the input.
Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_input = at::zeros(input_sizes, grad.options());
  grad_input.diagonal(offset, dim1, dim2).copy_(grad);
  return grad_input;
}
// Second derivative of MSE loss with respect to input. It is the constant 2,
// scaled by `grad` and divided by numel under mean reduction.
Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, int64_t reduction) {
  auto grad_input = 2 * grad;
  if (reduction != Reduction::Mean) {
    return grad_input;
  }
  grad_input /= input.numel();
  return grad_input;
}
// Double backward of MSE loss with respect to grad_output.
// - Unreduced: the backward is linear in grad_output, so it is applied to
//   `grad` directly.
// - Reduced: the backward is taken with a ones-like grad_output, multiplied
//   elementwise by `grad`, and summed.
Tensor mse_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  if (reduction == Reduction::None) {
    return mse_loss_backward(grad, input, target, reduction);
  }
  auto unit_grad = ones_like(grad_output);
  auto per_elem = mse_loss_backward(unit_grad, input, target, reduction);
  return (per_elem * grad).sum();
}
// Second derivative of soft margin loss log(1 + exp(-y x)) with respect to x:
//   y^2 * z / (1 + z)^2, where z = exp(-y x), y = target and x = input.
// It is scaled by `grad` and divided by numel under mean reduction.
Tensor soft_margin_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
  auto e = (input * -target).exp();
  auto one_plus_e = e + 1;
  auto grad_input = grad * (target * target) * e / (one_plus_e * one_plus_e);
  if (reduction == Reduction::Mean) {
    grad_input /= input.numel();
  }
  return grad_input;
}
// Double backward of soft margin loss with respect to grad_output.
// - Unreduced: the backward is linear in grad_output, so it is applied to
//   `grad` directly.
// - Reduced: the backward is taken with a ones-like grad_output, multiplied
//   elementwise by `grad`, and summed.
Tensor soft_margin_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
  if (reduction == Reduction::None) {
    return soft_margin_loss_backward(grad, input, target, reduction);
  }
  auto unit_grad = ones_like(grad_output);
  auto per_elem = soft_margin_loss_backward(unit_grad, input, target, reduction);
  return (per_elem * grad).sum();
}
// Double backward of softplus with respect to input. It is the sigmoid
// derivative at x = input * beta, times beta, applied only where
// x < threshold (the non-linear region).
Tensor softplus_double_backward(const Tensor & grad, const Tensor & input, Scalar beta, Scalar threshold) {
  auto x = (input * beta);
  auto below_threshold = (x < threshold).toType(grad.type());
  return _sigmoid_backward(grad, x.sigmoid()) * below_threshold * beta;
}
// NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
//
// `storage_offset` is ignored for simplicity in this note. If you just want the
// full algorithm without explanation, scroll down to bottom of this note.
//
// Implementing the backward of as_strided is tricky because you have to deal
// with mappings that map one memory location to multiple indices, i.e., the
// output tensor has multiple indices pointing to **overlapping** memory
// addresses. This can happen in all sorts of weird cases. For example,
//
// x = torch.randn(15)
// x.as_strided([3, 3], [1, 0]) # "expand" case
// x.as_strided([3, 3], [2, 1]) # "size too large" case
// x.as_strided([3, 2], [3, 6]) # res[2, 0] points to 2*3 + 0*6 = 6
// # res[0, 1] points to 0*3 + 1*6 = 6
//
// Here is the general strategy we apply in implementing as_strided backward:
// 0. ??? (optimization step. we will talk about this later)
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor.
// 3. ??? (fix for input tensor with overlapping memory. we will talk about
// this later)
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// In step (2), if the output tensor doesn't have overlapping memory, we can
// safely scatter (`storage.as_strided(output_geometry).copy_(grad)`);
// otherwise, we must use `index_add` as gradients at different indices may need
// to be summed to a single location.
//
// For example, in this case:
//
// x = torch.randn(3)
// y = x.as_strided([3, 3], [1, 0]) # "expand" case
// # size [ 3, 3]
// # stride [ 1, 0]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): since `y` has overlapping memory, index_add grad
// into `s` based on `y`'s geometry, i.e.,
// s[i * y.stride(0) + j * y.stride(1)] += gy[i, j].
// s = [ 3, 3, 3]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 3, 3, 3]
// grad_input = s.as_strided(x.size(), x.stride())
// = s.as_strided([3], [1])
// = [ 3, 3, 3]
//
// This is exactly what we would get if using `expand`. However, here the input
// tensor doesn't have overlapping memory. If it does, we must add an extra step
// before (4). Considering this case:
//
// t = torch.randn(3)
// x = t.expand(3, 3) # input with overlapping memory
// # size [3, 3]
// # stride [0, 1]
// y = x.as_strided([1], [1]) # contiguous output
// # size [1]
// # stride [1]
// y.backward() # step (1): contiguous storage tensor `s` of size 3, which
// is large enough to be used as underlying storage
// for `x` and `y`.
// s = [ 0, 0, 0]
// # step (2): scatter grad into `s` based on `y`'s geometry
// s = [ 1, 0, 0]
// # step (4): as_strided view `s` using `x`'s geometry
// s = [ 1, 0, 0]
// grad_input = s.as_strided([3, 3], [0, 1])
// = s.as_strided([3, 3], [0, 1])
// = [[ 1, 0, 0],
// [ 1, 0, 0],
// [ 1, 0, 0]]
// Is this result correct?
//
// `x.as_strided([1], [1])` call is obviously equivalent to
// `x[(0,) * x.dim()].view(1)` for any `x`. But autograd through the second
// gives gradient `[ [ 1, 0, 0], [ 0, 0, 0], [ 0, 0, 0]]`. For this specific
// case, indexing `x` at any index in first column is also equivalent, and
// yields a gradient of shape `[3 x 3]` containing eight 0's and one 1. There is
// an `x.size(1)`-times difference between these gradients computed from other
// PyTorch ops and the gradient we got from as_strided.
//
// You might conclude that the gradients from as_strided are wrong. However,
// let's first see why they are actually reasonable. Consider the pointwise
// perturbations by `delta` anywhere in the first column of `x`. It will lead to
// a `delta` change in the same memory location, and then `y` will change by
// `delta`. So one can say the gradient should be exactly 1 at the first column,
// as given by our above procedure.
//
// In the above computation of numerical gradients, they only match the
// analytical results because strides and memory locations are considered in the
// forward pass, i.e., this op (including both forward and backward) is
// layout-aware.
//
// However, in PyTorch, most (probably all) other ops (forward and backward) are
// layout-agnostic. E.g.,
//
// t = torch.randn(1)
// x = t.expand(2)
// y = x.sum()
// y.backward()
//
// Layout-agnostic autograd (as it is currently in PyTorch) will give you
//
// gy = 1
// gx = [ 1, 1] # SumBackward: torch.ones_like(x)
// gt = [ 2] # ExpandBackward: gx.sum()
//
// Note that `gx = [ 1, 1]`. However, if you perturb any value in `x` by `delta`
// (the other will also change by `delta`), `y` will change by `2 * delta`. So
// the gradients, if strides are taken into consideration, should be 2.
//
// Layout-aware autograd should give you
//
// gy = 1
// gx = [ 2, 2] # Because the backward considers the fact that the input `x`
// # is already expanded.
// gt = [ 2] # Layout-aware backward of expand is just a slicing because
// # the previous backward should have already taken care of
// # strides and made sure that gradients are the same along the
// # expanded dimension.
//
// As shown above, these two types are not compatible. Therefore, we must either
// make as_strided layout-agnostic, or make all other ops layout-aware.
//
// It is difficult to support layout-aware autograd (at least in the current
// codebase structure), because it would mean
// 1. storing tensor geometries of every input tensor for backward
// 2. depending on input geometry, the gradient computed from backward changes
// 3. ideally enforcing gradient of T to always have same strides as T
// (although these two methods only differ when it comes to overlapping memory)
//
// Therefore, we must formulate `as_strided` in a layout-agnostic way, i.e.,
// giving the same output regardless of the input layout. We consider
// `input.stride()` as a separate independent fixed argument `input_stride`.
// Then, `as_strided(input, size, stride)` can be thought of as:
// 1. "Scatter" each value of `input` into a "storage" using storage location
// computed from the value's index in `input`, `input.size()` and
// `input_stride`, but if N values end up in the same location, the value
// is average of those N values (they will be the same value anyways).
//
// Formal description:
// Denote the set of all input indices that point to the same storage
// location `storage[n]` as `S(n)`, i.e.,
//
// S(n) = { index : <index, input_stride> == n, index is valid given input.size() },
//
// where `<x, y>` is the dot product between `x` and `y`.
//
// Then, the process is:
//
// storage[n] = Avg { S(n) }
//
// Note that all values in `S(n)` are the same (they point to the same
// memory location anyway), so this step doesn't change anything, but
// effectively avoids having the dependency on the layout of `input`.
// I.e., the result holds fixed regardless of the layout of `input`, as
// long as `input_stride` is fixed.
//
// NOTE: for forward pass, we can equivalently simply select any one of
// `S(n)` as `storage[n]`. However, considering this as an average
// operation makes backward easier (so all values in set
// `{ grad_input[i] : i in S(n) }` are the same, and it can use the
// same geometry as input).
// 2. As usual, return the as_strided view of `storage` using required output
// `size` and `stride`.
//
// To backward through this layout-agnostic version, we simply add the following
// step:
// .... (scatter gradients into the storage tensor using output geometry)
// 3. For all storage location n, `storage[n] /= |S(n)|`.
// .... (return as_strided view of the storage tensor using input geometry)
//
// Finally, we note that these general operations are expensive, so we apply the
// following optimizations:
// Add step (0): For all output dimension `d` with output stride 0, sum the
// gradients along dimension `d` (don't keepdim), and remove
// dimension `d` from output size and stride.
// (An optimization for "expand" cases so we may avoid step (3))
// Only apply step (3) when input tensor has overlapping memory.
//
// FULL ALGORITHM:
// 0. For all output dimension `d` with output stride 0, sum the gradients
// along dimension `d` (don't keepdim), and remove dimension `d` from
// output size and stride.
// 1. Create some underlying flattened tensor as if it is the base tensor
// representing the contiguous memory storage for both input and output.
// 2. Use the output geometry to scatter (or index_add) the gradients into
// this storage tensor `storage`.
// 3. If input tensor has overlapping memory,
// For all storage location `i`, `storage[i] /= N(i)`, where `N(i)` is the
// number of indices in input geometry pointing to the same storage
// location `i` (i.e., `|S(i)|` in equations above).
// 4. Return the as_strided view of the storage tensor using input geometry.
//
// See NOTE [ Detecting Memory Overlap Within A Strided Tensor ] on how to
// roughly detect overlapping memory.
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
//
// Checking memory overlap within a strided tensor is the special case of
// detecting memory overlap of two strided tensors, where the two tensors start
// at the same memory address. The latter is HARD (see #8212).
//
// But even this special case isn't simple. This note describes a check for an
// even more constrained simple case where we can be certain that there is no
// overlap.
//
// The checking algorithm can be described as:
// 0. Return [ pass check ] if any dimension has size 0
// 1. Ignore all dimensions that have size 1
// 2. If no remaining dimensions, return [ pass check ]
// 3. Sort the remaining dimensions according to the strides decreasingly
// 4. Check that for each dimension k,
//
// stride[k] > \sum_{ i > k } (size[i] - 1) * stride[i]
//
// That is equivalent to, after reordering the dimensions so strides are
// in decreasing order, checking that stride of each dimension is larger
// than the maximum memory offset in a slice at that dimension.
//
// Obviously this check passes for contiguous tensors ( the dimensions will be
// already sorted with LHS = stride[0] = \prod_{i > 0} size[i] being exactly 1 larger
// than RHS ). Similarly, the check passes for tensors contiguous in all but
// the last dimension, and LHS = stride[0] = stride[-1] * \prod_{i > 0} size[i] being
// exactly stride[-1] larger than RHS. (*)
//
// We will show that these view operations, including all our view operations
// *except for* general as_strided and unfold, also preserve this invariant:
//
// alias: Obviously preserves
//
// expand: All changed dimensions are removed in step (1)
//
// view: Consider the input dimensions as grouped into consecutive
// dimension "blocks", where dimensions are contiguous in each
// one. view only works when the output dimensions can also be
// grouped into the same consecutive blocks of same ordering.
//
// NB: this means that the number of elements and stride of the
// last dimension in each block is the same in input and
// output. (**)
//
// Notation:
// Consider a single such block B,
// ... B_prev[-1]], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ B_next[0], ...
// start--^^^^ ^^^^^^^^^^^^--end
// Each B[i] denotes a dimension index such that B[i] = B[0] + i.
//
// We first show that if a tensor (i.e., input) satisfies the
// invariant, after sorting, the dimensions within each block
// still remain consecutive. (***)
//
// After removing dimensions of size 1, the dimensions within a
// block is already sorted by strides in descending order. So
// sorting all dimensions will not change the relative ordering
// among them.
//
// Assume that some block B is not consecutive after sorting,
// i.e., there exists a dimension d between B[0] and B[-1] in
// sorted order.
//
// By (*), we know that
// stride[B[0]]
// = \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[B[-1]]
// < \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + stride[d]
// <= \sum_{i > 0} (size[B[i]] - 1) * stride[B[i]] + (size[d] - 1) * stride[d]
// <= \sum_{j > B[0]} (size[j] - 1) * stride[j],
//
// where the first < comes from sorting and
// the second <= comes from the fact that dimension d
// exists after step (1) and
// thus must have size greater
// than 1
// the third <= comes from the fact that each term in
// the sum is non-negative
//
// Then we have a contradiction, as the invariant would not be
// satisfied at B[0]. So the original proposition is true.
//
// Now that we established the above claim (***), we consider the
// view operation as first sorting the dimensions (i.e., blocks),
// apply the original view (since it only cares about dimensions being
// consecutive and contiguous within each block), and then undo
// the sort.
//
// Consider a single block B in the output,
// ... ], [ B[0], ..., B[i], ..., B[k] = B[-1] ], [ ...
// start--^^^^ ^^^^^^^^^^^^--end
//
// By (*), we know that for all i
// stride[B[i]] = stride[B[-1]] +
// \sum_{j=i+1}^{k} (size[B[j]] - 1) * stride[B[j]]
//
// Then the invariant is obviously satisfied at every dimension
// in this block if it is satisfied at dimension B[-1]. It only
// remains to show that it is satisfied at the last dimension in
// each block.
//
// Since the same blocks are present in both input and output
// with the same ordering, we will abuse the notation in the
// following statements.
//
// By (*), we know that the following holds for both input and
// output, for any block B:
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// = \sum_{block B' after B} (\prod_{j in B'} size[j] - 1) * stride[B'[-1]]
// = \sum_{block B' after B} (numel(B') - 1) * stride[B'[-1]].
// ^^^^^^^^^^^^^^^^^^^^^^^|^^^^^^^^^^^^^^^^^^^^^^^^^^
// By (**), we know that, this quantity in the above equation
// remains the same in input and output. So both
// \sum_{i > B[-1]} (size[i] - 1) * stride[i]
// and
// stride[B[-1]]
// are the same in input and output.
//
// These two quantities are exactly the LHS and RHS of the
// invariant inequality. Since by assumption the invariant is
// satisfied in input at B[-1], it is also satisfied in output at
// B[-1]. This concludes the proof.
//
// squeeze: Special case of view
//
// unsqueeze: Special case of view
//
// slice: Consider slicing dimension i with step = k >= 1.
//
// Let stride' and size' be the output strides and sizes. We have
//
// stride'[i] = k * stride[i]
// size'[i] <= ceil(size[i] / k),
// which implies (size'[i] - 1) * k <= size[i] - 1.
//
// If size'[i] = 1, invariant is obviously satisfied as we are
// just removing a dimension (after step (1)).
//
// Assume size'[i] > 1.
//
// By assumption, the invariant is satisfied at every dimension
// in input.
//
// For any dimension j, if stride[j] > stride[i], we have
// stride'[j] = stride[j]
// > (size[i] - 1) * stride[i]
// >= (size'[i] - 1) * k * stride[i]
// = (size'[i] - 1) * stride'[i]
// >= stride'[i].
//
// If stride[j] < stride[i], we have
// stride'[j] = stride[j] < stride[i] <= stride'[i].
//
// So the sorting order remains unchanged after slice.
//
// Since
// (size'[i] - 1) * stride'[i]
// = (size'[i] - 1) * k * stride[i]
// <= (size[i] - 1) * stride[i],
// the term from this dimension i in the invariant inequality at
// other dimensions can only decrease after slice. So the
// invariant is preserved.
//
// narrow: Special case of slice
//
// select: narrow + squeeze
//
// permute: Sorting makes permutation of dimensions irrelevant
//
// transpose: Sorting makes swapping dimensions irrelevant
//
// diagonal: Effectively merging two dimensions i and j into a new
// dimension k s.t.
// stride'[k] = stride[i] + stride[j]
// size'[k] <= min(size[i], size[j]),
// where stride and size are on the input, and stride' and size'
// are on the output.
//
// Assuming that size[i] > 1 and size[j] > 1. If any has size 1,
// then this is unsqueeze on that dimension.
//
// WLOG, say stride[i] <= stride[j].
//
// Each dimension d in input with stride[d] > stride[j] has
// stride'[d] = stride[d]
// > (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j]
// >= stride[i] + stride[j]
// = stride'[k].
// So, considering the sorted dimensions, this is effectively
// removing i, and replacing j with k.
//
// For dimensions d with stride[i] < stride[d] < stride[j], the
// term from dimension i is removed in the invariant inequality.
// For dimensions d with stride[d] > stride[j], we have
// (size'[k] - 1) * stride'[k]
// <= (min(size[i], size[j]) - 1) * (stride[i] + stride[j])
// <= (size[i] - 1) * stride[i] + (size[j] - 1) * stride[j],
// so the term from i and j in the invariant can only decrease.
//
// So this is generally relaxing the constraint, and thus it
// preserves it.
// Helper for as_strided_backward.
// This implements steps (2)~(4) of the algorithm in
// NOTE [ Detecting Memory Overlap Within A Strided Tensor ]. Size-0 and
// size-1 dimensions (steps (0)~(1)) are expected to have been handled by the
// caller already.
//
// Returns false only when the layout is provably free of self-overlap.
// Returns true when two indices *may* map to the same memory location.
static inline bool _maybe_overlapping_memory(IntList sizes, IntList strides) {
  if (sizes.empty()) {
    return false;
  }
  // Visit the dimensions from the smallest stride to the largest.
  std::vector<std::size_t> order(sizes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](std::size_t lhs, std::size_t rhs) { return strides[lhs] < strides[rhs]; });
  // Largest offset reachable using only the dimensions visited so far.
  int64_t reachable = 0;
  for (std::size_t dim : order) {
    const int64_t dim_stride = strides[dim];
    if (dim_stride <= reachable) {
      // This stride lands inside a range already spanned by smaller strides.
      return true;
    }
    reachable += dim_stride * (sizes[dim] - 1);
  }
  return false;
}
// Helper for as_strided_backward.
// Returns the minimum storage size needed to contain a tensor of sizes,
// strides, and storage_offset. For a tensor with any size-0 dimension, this
// is just `storage_offset`. Otherwise it is one past the largest addressed
// element, i.e. storage_offset + sum_d (sizes[d] - 1) * strides[d] + 1.
static inline int64_t _min_storage_size(IntList sizes, IntList strides, int64_t storage_offset) {
  int64_t last_index = storage_offset;
  for (std::size_t d = 0; d < sizes.size(); ++d) {
    const int64_t extent = sizes[d];
    if (extent == 0) {
      // An empty tensor addresses nothing beyond its offset.
      return storage_offset;
    }
    last_index += (extent - 1) * strides[d];
  }
  return last_index + 1;
}
// See NOTE [ as_strided Backward and layout-aware/agnostic autograd ] for explanation
//
// Backward of `input.as_strided(sizes, strides, storage_offset)`.
//   grad           - gradient w.r.t. the as_strided output. It is taken by
//                    value because it is squeezed/summed locally while the
//                    output geometry is normalized.
//   input_geometry - sizes, strides and storage offset of the forward input
//   sizes, strides, storage_offset - the output geometry given to as_strided
// Returns the gradient w.r.t. the input. It is laid out with the input's
// sizes and strides, as a view into a freshly allocated "storage" tensor.
// NOTE(review): indexing sizes[i]/strides[i] for i < grad.dim() assumes
// sizes.size() == strides.size() == grad.dim().
Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList sizes, IntList strides, int64_t storage_offset) {
  // For output geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  //   reduce grad on expanded dims (stride=0, size>1)
  // Step (0)     for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto odim = grad.dim();
  std::vector<int64_t> out_sizes_, out_strides_;
  out_sizes_.reserve(odim);
  out_strides_.reserve(odim);
  // Walk dimensions from last to first so that squeeze/sum on dim i never
  // shifts the indices of dimensions not yet visited. Surviving dimensions
  // are prepended so that they keep their original order.
  for (int64_t i = odim - 1; i >= 0; i--) {
    auto size_i = sizes[i];
    auto stride_i = strides[i];
    if (size_i == 0) {
      // Empty output: nothing flows back, so the input gradient is all zeros.
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i == 1) {
      grad = grad.squeeze(i);
    } else if (stride_i == 0) {
      // Expanded dimension: all its indices alias one location, so sum them.
      grad = grad.sum(i, false);
    } else {
      out_sizes_.insert(out_sizes_.begin(), size_i);
      out_strides_.insert(out_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on output geometry
  auto out_maybe_overlap = _maybe_overlapping_memory(out_sizes_, out_strides_);

  // For input geometry,
  //   check for size 0 dimensions,
  //   skip size 1 dimensions,
  // Step (0)~(1) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto idim = input_geometry.dim();
  IntList inp_sizes = input_geometry.sizes(), inp_strides = input_geometry.strides();
  std::vector<int64_t> inp_sizes_, inp_strides_;
  inp_sizes_.reserve(idim);
  inp_strides_.reserve(idim);
  for (int64_t i = idim - 1; i >= 0; i--) {
    auto size_i = inp_sizes[i];
    auto stride_i = inp_strides[i];
    if (size_i == 0) {
      // Empty input: its gradient is an empty (all-zero) tensor.
      return at::zeros(input_geometry.sizes(), grad.options());
    } else if (size_i != 1) {
      inp_sizes_.insert(inp_sizes_.begin(), size_i);
      inp_strides_.insert(inp_strides_.begin(), stride_i);
    }
  }
  // Step (2)~(4) for the algorithm in NOTE [ Detecting Memory Overlap Within A Strided Tensor ]
  //              on input geometry
  auto inp_maybe_overlap = _maybe_overlapping_memory(inp_sizes_, inp_strides_);

  // Rest of this function implements
  // Step (1)~(4) for the algorithm in NOTE [ as_strided Backward and layout-aware/agnostic autograd ]
  // TODO: Raise if not all output values are visible in input geometry.
  //       Technically speaking, if you treat those values as constants, not
  //       raising is fine, and mathematically correct. However, these values
  //       really are contained in some base tensor, and by treating them as
  //       constants we are ignoring this tight dependency. Therefore, it is
  //       more sensible to raise here.

  // Step (1): create underlying tensor as "storage"
  // Offsets are rebased so that the smaller of the two starts at index 0;
  // the storage is just large enough to hold both geometries.
  auto shared_offset = std::min(input_geometry.storage_offset(), storage_offset);
  auto inp_effective_offset = input_geometry.storage_offset() - shared_offset;
  auto out_effective_offset = storage_offset - shared_offset;
  auto base_size = std::max(
    _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
    _min_storage_size(out_sizes_, out_strides_, out_effective_offset)
  );
  auto storage = at::zeros({base_size}, grad.options());

  // prepare indices tensor if we will do index_add_ later
  // (flatten_full_indices[p] == p, viewed through a geometry it yields the
  // storage position addressed by each index of that geometry)
  c10::optional<at::Tensor> flatten_full_indices;
  if (inp_maybe_overlap || out_maybe_overlap) {
    flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
  }

  // Step (2): use output geometry to scatter gradients into storage
  if (out_maybe_overlap) {
    // Several output indices may hit one location: accumulate with index_add_.
    auto out_indices = flatten_full_indices->as_strided(out_sizes_, out_strides_, out_effective_offset);
    storage.index_add_(0, out_indices.reshape(-1), grad.reshape(-1));
  } else {
    // No overlap: a plain strided copy suffices.
    // assume that new tensors have 0 storage offset
    storage.as_strided(out_sizes_, out_strides_, out_effective_offset).copy_(grad);
  }

  // Step (3): if input tensor has overlapping memory, divide scattered gradient
  //           at storage[i] by the number of times i shows up in input geometry
  if (inp_maybe_overlap) {
    auto count = at::zeros_like(storage);
    auto inp_indices = flatten_full_indices->as_strided(inp_sizes_, inp_strides_, inp_effective_offset).reshape(-1);
    count.index_add_(0, inp_indices, at::ones({1}, grad.options()).expand_as(inp_indices));
    storage.div_(count); // this will give nan outside visible range
    // (positions with count == 0 are not addressed by the input geometry,
    // so the view returned in step (4) never reads them)
  }
  // Step (4): return as_strided view of the storage tensor with input geometry
  return storage.as_strided(inp_sizes, inp_strides, inp_effective_offset);
}
std::tuple<Tensor, Tensor> atan2_backward(const Tensor& grad, const Tensor& self, const Tensor& other, std::array<bool, 2> output_mask) {
auto recip = (self * self + other * other).reciprocal();
return std::tuple<Tensor,Tensor>{
output_mask[0] ? grad * other * recip : Tensor(),
output_mask[1] ? grad * -self * recip : Tensor() };
}
// TODO: Seriously consider writing the derivative formulas for
// each output separately; there is not all that much sharing
// of computation going on here.
//
// Double backward of PReLU.
//   grad_grad_input, grad_grad_weight - incoming gradients, presumably w.r.t.
//       the first backward's gI and gW outputs (see the explanation below).
//       If undefined, they are treated as zeros.
//   grad_out - grad_output fed to the first backward (zeros if undefined)
//   input_, weight_ - forward inputs (made contiguous below)
// Returns (ggO, gI, gW): gradients w.r.t. grad_out, input and weight.
// In the per-channel branch, ggO is left undefined unless gO.requires_grad().
std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
    const Tensor & grad_grad_input,
    const Tensor & grad_grad_weight,
    const Tensor & grad_out,
    const Tensor & input_,
    const Tensor & weight_) {
  auto input = input_.contiguous();
  auto weight = weight_.contiguous();
  // Zero-fill undefined grads (TODO: do this more efficiently)
  auto ggI = grad_grad_input.defined() ? grad_grad_input.contiguous() : at::zeros_like(input);
  auto ggW = grad_grad_weight.defined() ? grad_grad_weight.contiguous() : at::zeros_like(weight);
  auto gO = grad_out.defined() ? grad_out.contiguous() : at::zeros_like(input);

  // 0/1 masks selecting the i > 0 and i <= 0 regions of the input.
  auto positive_mask = (input > 0).type_as(ggI);
  auto nonpositive_mask = (input <= 0).type_as(ggW);

  // Explanation: Let input be i, weight be w, grad_output be gO.
  // f(i, w) = i      if i > 0
  //         = w * i  if i <= 0
  // gI = df/di * gO  = gO      if i > 0      gW = df/dw * gO = 0         if i > 0
  //                  = gO * w  if i <= 0                     = gO * i    if i <= 0
  // The rest is taking derivatives of these wrt i, w, gO and summing/expanding properly.

  if (weight.numel() == 1) {
    // from PReLU.forward: num_parameters == 0 is used to indicate that a
    // single weight is shared among all input channels.

    // this is a little tricky because PReLU currently doesn't take a shape so the weight may be
    // 1-d when the input is a scalar (and there isn't a good Parameter API for that anyway until Variable
    // and tensor are merged).  So, use weight and ggW as 0-dim in this case.
    bool scalar_input_1d_weight = (positive_mask.dim() == 0 && weight.dim() == 1);
    auto weight_maybe_squeeze = scalar_input_1d_weight ? weight.squeeze() : weight;
    auto ggW_maybe_squeeze = scalar_input_1d_weight ? ggW.squeeze() : ggW;

    auto mask = positive_mask + nonpositive_mask * weight_maybe_squeeze.expand_as(input);
    auto ggO = ggI * mask + ggW_maybe_squeeze.expand_as(gO) * (nonpositive_mask * input);
    return std::tuple<Tensor, Tensor, Tensor>(
                ggO,
                ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
                // The shared weight receives the sum over all elements.
                (ggI * gO * nonpositive_mask).sum().expand_as(weight)
          );
  } else {
    // Expand ggW to match size of ggI; a simple expand doesn't work because
    // ggW is the size of the input channel (dim==1 unless there is only 1 dimension).  For example,
    // let ggI be size (3,4,5,6,7) and ggW be size (4).  Then we unsqueeze ggW to be size (4,1,1,1)
    // so the expand succeeds.
    auto dims_to_unsqueeze = std::max<int64_t>(input.dim() - 2, 0);
    auto ggW_expanded = ggW;
    for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
      ggW_expanded = ggW_expanded.unsqueeze(1);
    }
    ggW_expanded = ggW_expanded.expand_as(ggI);

    auto gI = ggW_expanded * gO * nonpositive_mask;

    // Reduce to the weight's shape: sum over dim 0 (batch, when present),
    // then over every trailing dimension beyond the channel dimension.
    auto gW = ggI * gO * nonpositive_mask;
    if (input.dim() > 1) {
      gW = gW.sum(0);
    }
    while (gW.dim() > 1) {
      gW = gW.sum(1);
    }

    Tensor ggO;
    if (gO.requires_grad()) {
      // expand weight as input as in ggW/ggI above
      auto weight_expanded = weight;
      for (int64_t i = 0; i < dims_to_unsqueeze; i++) {
        weight_expanded = weight_expanded.unsqueeze(1);
      }
      weight_expanded = weight_expanded.expand_as(input);

      auto mask = positive_mask + nonpositive_mask * weight_expanded;
      ggO = ggI * mask + ggW_expanded * nonpositive_mask * input;
    }
    return std::tuple<Tensor,Tensor,Tensor>{ggO, gI, gW};
  }
}
// https://j-towns.github.io/papers/svd-derivative.pdf
//
// This makes no assumption on the signs of sigma.
//
// Backward of torch.svd for a single matrix `self` of shape (m, n).
//   grads      - {gu, gsigma, gv}: gradients w.r.t. U, S and V; any may be undefined
//   some       - if false, U/V (and gu/gv) are narrowed to their first
//                k = sigma.size(0) columns before use
//   compute_uv - must be true; U and V are required for the backward
//   raw_u, sigma, raw_v - outputs of the forward svd
// Returns the gradient w.r.t. `self` as the sum of the U, sigma and V terms.
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
          bool some, bool compute_uv, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
  AT_CHECK(compute_uv,
           "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
           "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");

  auto m = self.size(0);
  auto n = self.size(1);
  auto k = sigma.size(0);
  auto gsigma = grads[1];

  auto u = raw_u;
  auto v = raw_v;
  auto gu = grads[0];
  auto gv = grads[2];

  if (!some) {
    // We ignore the free subspace here because possible base vectors cancel
    // each other, e.g., both -v and +v are valid base for a dimension.
    // Don't assume behavior of any particular implementation of svd.
    u = raw_u.narrow(1, 0, k);
    v = raw_v.narrow(1, 0, k);
    if (gu.defined()) {
      gu = gu.narrow(1, 0, k);
    }
    if (gv.defined()) {
      gv = gv.narrow(1, 0, k);
    }
  }
  auto vt = v.t();

  // Contribution of gsigma: U diag(gsigma) V^T (zeros if undefined).
  Tensor sigma_term;
  if (gsigma.defined()) {
    sigma_term = u.mm(gsigma.diag()).mm(vt);
  } else {
    sigma_term = at::zeros({1}, self.options()).expand_as(self);
  }
  // in case that there are no gu and gv, we can avoid the series of kernel
  // calls below
  if (!gv.defined() && !gu.defined()) {
    return sigma_term;
  }

  auto ut = u.t();
  auto im = eye(m, self.options());
  auto in = eye(n, self.options());
  auto sigma_mat = sigma.diag();
  auto sigma_mat_inv = sigma.pow(-1).diag();
  // sigma_expanded_sq[i][j] = sigma_j^2, so F[i][j] = sigma_j^2 - sigma_i^2.
  auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
  auto F = sigma_expanded_sq - sigma_expanded_sq.t();
  // The following two lines invert values of F, and fills the diagonal with 0s.
  // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
  // first to prevent nan from appearing in backward of this function.
  F.diagonal().fill_(INFINITY);
  F = F.pow(-1);

  Tensor u_term, v_term;

  if (gu.defined()) {
    u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat);
    // Extra term for the component of gu orthogonal to span(U) (only when m > k).
    if (m > k) {
      u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv);
    }
    u_term = u_term.mm(vt);
  } else {
    u_term = at::zeros({1}, self.options()).expand_as(self);
  }

  if (gv.defined()) {
    auto gvt = gv.t();
    v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt);
    // Extra term for the component of gv orthogonal to span(V) (only when n > k).
    if (n > k) {
      v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt)));
    }
    v_term = u.mm(v_term);
  } else {
    v_term = at::zeros({1}, self.options()).expand_as(self);
  }

  return u_term + sigma_term + v_term;
}
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Backward of torch.symeig.
//   grads  - {glambda, gv}: gradients w.r.t. eigenvalues and eigenvectors;
//            either may be undefined
//   eigenvectors - must be true; eigenvectors are required for the backward
//   upper  - selects which triangle the result is folded into at the end
//   lambda, v - eigenvalues and eigenvectors from the forward
// Returns the gradient w.r.t. `self`. The gradient is symmetrized into one
// triangle: the chosen triangle of `result` plus the transpose of the
// strictly opposite part.
Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                    bool eigenvectors, bool upper, const Tensor& lambda, const Tensor& v) {
    AT_CHECK(eigenvectors,
             "symeig_backward: Setting eigenvectors to false in torch.symeig doesn't compute eigenvectors ",
             "and hence we cannot compute backward. Please use torch.symeig(eigenvectors=True)");

    auto glambda = grads[0];
    auto gv = grads[1];

    auto vt = v.t();

    Tensor result;
    if (gv.defined()) {
        // F[i][j] = 1 / (lambda_j - lambda_i) off the diagonal. The diagonal
        // is set to +inf before inverting, so it becomes 0 instead of nan.
        Tensor F = lambda.unsqueeze(0).expand_as(self).clone();
        F.sub_(at::unsqueeze(lambda, 1));
        F.diagonal().fill_(INFINITY);
        F.pow_(-1);

        F.mul_(vt.mm(gv));
        result = v.mm(F.mm(vt));
    } else {
        result = at::zeros_like(self);
    }

    // Eigenvalue contribution: V diag(glambda) V^T.
    if (glambda.defined()) {
        result.add_((v * glambda).mm(vt));
    }
    if (upper) {
        result = at::triu(result) + at::triu(result.t(), 1);
    } else {
        result = at::tril(result) + at::tril(result.t(), -1);
    }
    return result;
}
// Invertible case is derived from Jacobi's formula, and also can be found at:
// http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
//
// Gradient of det(self). It uses det * self^{-T} when det != 0. Otherwise
// it routes the gradient through the singular values, because the inverse
// does not exist.
Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
  if (det.item<double>() == 0) {
    // Singular: det = \prod(sigma) = 0, so differentiate the product of
    // singular values and push that through svd_backward.
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    auto gsigma = prod_backward(grad, sigma, det);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
  // Invertible: d det(A) / dA = det(A) * A^{-T}.
  return grad * det * self.inverse().t();
}
// Gradient of logdet(self).
// When logdet is not -inf (det != 0), the gradient is self^{-T}. Otherwise
// the gradient is routed through the singular values.
Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
  const bool singular = logdet.item<double>() == -INFINITY;
  if (singular) {
    // det = \prod(sigma) = 0; backward of logdet = \sum log(sigma).
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    auto gsigma = grad.div(sigma);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
  return grad * self.inverse().t();
}
// Gradient of slogdet(self). Only the logabsdet output (grads[1]) carries a
// gradient; the sign output must never receive one.
Tensor slogdet_backward(const std::vector<torch::autograd::Variable> &grads,
                        const Tensor& self,
                        const Tensor& signdet, const Tensor& logabsdet) {
  AT_ASSERTM(!grads[0].defined(), "slogdet's sign output should never have gradient");
  const auto& grad_logabsdet = grads[1];
  if (signdet.item<double>() == 0) {
    // det = \prod(sigma) = 0. sigma is non-negative (with at least one zero
    // entry), so logabsdet = \sum log(sigma) and we differentiate through the
    // singular values.
    Tensor u, sigma, v;
    std::tie(u, sigma, v) = self.svd();
    auto gsigma = grad_logabsdet.div(sigma);
    return svd_backward({{}, gsigma, {}}, self, true, true, u, sigma, v);
  }
  // det != 0, invertible: d log|det(A)| / dA = A^{-T}.
  return grad_logabsdet * self.inverse().t();
}
// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
std::tuple<Tensor, Tensor> trtrs_backward(
const Tensor & grad_x, const Tensor & grad_m,
const Tensor & b, const Tensor & a, const Tensor & x,
const bool upper, const bool transpose, const bool unitriangular,
std::array<bool, 2> output_mask) {
Tensor grad_b, grad_a;
if (grad_x.defined()) {
grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
if (output_mask[1]) {
grad_a = transpose ? -x.mm(grad_b.t()) : -grad_b.mm(x.t());
if (upper) {
grad_a = grad_a.triu((int) unitriangular);
} else {
grad_a = grad_a.tril(-((int) unitriangular));
}
}
}
if (!grad_a.defined()) {
grad_a = at::zeros({1}, a.options()).expand_as(a);
}
if (!grad_b.defined()) {
grad_b = at::zeros({1}, b.options()).expand_as(b);
}
if (output_mask[1] && grad_m.defined()) {
grad_a = grad_a.add(grad_m);
}
return std::tuple<Tensor, Tensor>{grad_b, grad_a};
}
// Backward of _fft_with_size. Generally speaking, fft's backward is ifft.
//
// self:                 the forward input; the result has its sizes, and
//                       self.size(0) is used as the leading (batch) extent.
// grad:                 gradient of the forward output.
// signal_ndim:          number of signal dimensions.
// complex_input/output: whether the forward consumed/produced complex data,
//                       stored with a trailing dimension of size 2.
// inverse, normalized,
// onesided:             the forward's flags.
// checked_signal_sizes: full (two-sided) signal sizes, one per signal dim.
// output_sizes:         not used by this function.
Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim,
                    bool complex_input, bool complex_output,
                    bool inverse, IntList checked_signal_sizes,
                    bool normalized, bool onesided,
                    IntList output_sizes) {
  Tensor gI;
  if (!complex_input && complex_output) {
    // Forward is R2C
    // Do inverse C2C and project onto real plane because grad can be
    // asymmetrical so C2R can't be used.
    if (onesided) {
      // Forward is R2C (onesided)
      // Think of onesided R2C rfft as
      //     1. view as complex numbers (fill complex dim with zeros)
      //     2. C2C fft
      //     3. discard half of results
      // So backward is
      //     1. fill the other half with zeros (with `zero_grad_shape` below)
      //        (C2C ifft only takes twosided inputs so we need to fill here)
      //     2. inverse C2C ifft
      //     3. discard the complex dim
      int64_t zero_length = checked_signal_sizes[signal_ndim - 1] - grad.size(signal_ndim);
      auto complex_full_grad = grad;
      if (zero_length > 0) {
        std::vector<int64_t> zero_grad_shape(signal_ndim + 2);
        zero_grad_shape[0] = self.size(0);
        for (int64_t i = 1; i < signal_ndim; i++) {
          zero_grad_shape[i] = checked_signal_sizes[i - 1];
        }
        zero_grad_shape[signal_ndim] = zero_length;
        zero_grad_shape[signal_ndim + 1] = 2;
        complex_full_grad = at::cat({ grad, at::zeros(zero_grad_shape, grad.options()) }, signal_ndim);
      }
      gI = _fft_with_size(complex_full_grad, signal_ndim,
                          /* complex_input */ true, /* complex_output */ true,
                          !inverse, checked_signal_sizes, normalized,
                          /* onesided */ false, complex_full_grad.sizes()).select(-1, 0);
    } else {
      gI = _fft_with_size(grad, signal_ndim, /* complex_input */ true,
                          /* complex_output */ true, !inverse,
                          checked_signal_sizes, normalized,
                          /* onesided */ false, grad.sizes()).select(-1, 0);
    }
  } else if (complex_input && !complex_output && onesided) {
    // Forward is C2R (onesided)
    // Think of onesided C2R irfft as
    //    1. fill the other half by conjugate symmetry
    //    2. inverse C2C ifft
    //    3. discard the complex dimension
    // So backward is
    //    1. R2C rfft (essentially add dummy complex dimension, and dft)
    //    2. accumulate gradient by conjugate symmetry
    //       since rfft results follow conjugate symmetry, we only need to
    //       double some entries from onesided rfft results, i.e., the ones with
    //       their reflected indices also landing out of the onesided range
    //       [0, floor(N/2)]. So consider the index of last dim:
    //           i.   idx = 0.
    //                Reflected to (N - 0) % N = 0. Not doubled.
    //           ii.  0 < idx < floor(N/2) (last).
    //                N > N - idx > ceil(N/2)
    //                Reflected to (N - idx) % N = N - idx, which is outside
    //                the onesided range. Doubled.
    //           iii. idx = floor(N/2) = N/2 (last) when N even.
    //                Reflected to (N - N/2) % N = N/2. Not doubled.
    //           iv.  idx = floor(N/2) = (N-1)/2 (last) when N odd.
    //                Reflected to (N - (N-1)/2) % N = (N+1)/2. Doubled.
    //       Therefore, needs to double
    //           idx = 1, 2, ..., N/2 - 1     when N even
    //           idx = 1, 2, ..., (N-1)/2     when N odd
    //       that is
    //           idx = 1, 2, ..., N - (floor(N/2) + 1)
    //               = 1, 2, ..., N - onesided_length
    gI = _fft_with_size(grad, signal_ndim, /* complex_input */ false,
                        /* complex_output */ true, /* inverse */ false,
                        checked_signal_sizes, normalized, /* onesided */ true,
                        self.sizes());
    int64_t double_length = checked_signal_sizes[signal_ndim - 1] - self.size(signal_ndim);
    if (double_length > 0) {  // also covers case when signal size is zero
      gI.narrow(signal_ndim, 1, double_length).mul_(2);
    }
  } else {
    gI = _fft_with_size(grad, signal_ndim, complex_output, complex_input,
                        !inverse, checked_signal_sizes, normalized, onesided,
                        self.sizes());
  }
  if (normalized) {
    // If normalized, backward is exactly calling fft with inversed argument as
    // the forward because both are unitary.
    return gI;
  } else {
    // If not normalized, in backward, we need to upscale or downscale gI based
    // on whether the forward is an inverse fft.
    // The init value must be int64_t: std::accumulate accumulates in the type
    // of its init argument, so a plain `1` would truncate the product to int
    // for large signals.
    const int64_t signal_numel = std::accumulate(checked_signal_sizes.begin(),
        checked_signal_sizes.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
    if (!inverse) {
      return gI.mul_(static_cast<double>(signal_numel));
    } else {
      return gI.div_(static_cast<double>(signal_numel));
    }
  }
}
// Helper for batchnorm_double_backward.
// Sums `to_sum` over every dimension except dim 1. With keepdim=true every
// reduced dimension is kept with size 1; with keepdim=false only the dim-1
// extent remains.
Tensor sum_exclude_dim1(const Tensor& to_sum, bool keepdim=true) {
  Tensor reduced = to_sum.sum(0, keepdim);
  // After reducing dim 0, the original dim 1 sits at index 1 (kept) or at
  // index 0 (dropped). Reduce every dimension after it, last one first, so
  // the remaining indices stay valid when dimensions are dropped.
  const int64_t preserved_dim = keepdim ? 1 : 0;
  int64_t dim = reduced.dim() - 1;
  while (dim > preserved_dim) {
    reduced = reduced.sum(dim, keepdim);
    --dim;
  }
  return reduced;
}
// Helper for batchnorm_double_backward.
// Similar to expand_as_dim1, but without the final expand: `src` is reshaped
// with singleton dimensions as if it had come out of reductions done with
// keepdim=True against `target`.
Tensor unsqueeze_dim1(const Tensor& src, const Tensor& target) {
  const size_t non_batch_rank = target.sizes().size() - 1;
  Tensor shaped = src;
  // Each unsqueeze adds exactly one dimension, so this pads with trailing
  // singleton dims (inserted at position 1) up to target's rank minus one.
  for (size_t rank = src.sizes().size(); rank < non_batch_rank; ++rank) {
    shaped = shaped.unsqueeze(1);
  }
  // Then add the leading size-1 dimension that stands in for dim 0.
  if (shaped.sizes().size() == non_batch_rank) {
    shaped = shaped.unsqueeze(0);
  }
  return shaped;
}
// Helper for batchnorm_double_backward.
// gamma/ggG/ggB are 1-dimensional and represent dim==1, so a plain expand_as
// would not follow the broadcasting rules (it aligns trailing dims). Trailing
// singleton dims are added first so `src` lines up with dim 1 of `target`,
// and the result is then expanded to target's shape.
Tensor expand_as_dim1(const Tensor& src, const Tensor& target) {
  const size_t non_batch_rank = target.sizes().size() - 1;
  Tensor shaped = src;
  // Each unsqueeze adds exactly one dimension.
  for (size_t rank = src.sizes().size(); rank < non_batch_rank; ++rank) {
    shaped = shaped.unsqueeze(1);
  }
  return shaped.expand_as(target);
}
// Double backward of batch norm. The first backward mapped gO to
// (grad_input, grad_weight, grad_bias); this differentiates it w.r.t.
// input, gamma and gO.
//
// ggI, ggG, ggB: incoming gradients for the first backward's grad_input,
//                grad_weight and grad_bias outputs; any may be undefined.
// gamma:         affine weight; undefined when the layer is not affine.
// training:      if true, uses save_mean / save_invstd; otherwise uses
//                running_mean and (running_var + eps)^(-1/2).
// output_mask:   which results are required, indexed like the returned
//                tuple: 0 -> gI, 1 -> gG, 2 -> ggO.
//
// Returns (gI, gG, ggO): gradients w.r.t. input, gamma and gO. Requested
// results that received no contribution are returned as zeros.
std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
    const Tensor & input,
    const Tensor & gamma,
    const Tensor & ggI,
    const Tensor & ggG,
    const Tensor & ggB,
    const Tensor & gO,
    const Tensor & running_mean,
    const Tensor & running_var,
    bool training,
    double eps,
    const Tensor & save_mean,
    const Tensor & save_invstd,
    std::array<bool,3> output_mask) {

  bool affine = gamma.defined();
  // TODO: Do we have a ScalarOrTensor type?  Would such a thing exist?
  Tensor gamma_expanded;
  Tensor ggG_expanded, ggB_expanded;
  if (affine) {
    gamma_expanded = expand_as_dim1(gamma, input);
    if (ggG.defined()) {
      ggG_expanded = expand_as_dim1(ggG, input);
    }
    if (ggB.defined()) {
      ggB_expanded = expand_as_dim1(ggB, input);
    }
  } else {
    gamma_expanded = at::ones({}, input.options());
  }

  // define some terms we will reuse
  // M: number of elements reduced per channel (all dims except dim 1).
  auto M = input.size(0);
  for (auto s : input.sizes().slice(2)) {
    M *= s;
  }
  // for half inputs, save_mean, save_invstd are float (ideally, we would cast
  // everything else, but not now)
  auto mu = unsqueeze_dim1(training ? save_mean.to(input.type().scalarType()) : running_mean, input);
  auto input_sub_mu = input - mu;
  auto sigma2_eps_neg_1_2 = unsqueeze_dim1(training ? save_invstd.to(input.type().scalarType())
                                                    : running_var.add(Scalar(eps)).pow(-0.5), input);
  auto sigma2_eps_neg_1 = sigma2_eps_neg_1_2.pow(2);
  auto sigma2_eps_neg_3_2 = sigma2_eps_neg_1_2.pow(3);

  // calculate gI
  auto input_mu_sigma2_neg_3_2 = input_sub_mu * sigma2_eps_neg_3_2;
  // Reused below; never modified in place.
  auto gOinmu_sum = sum_exclude_dim1(gO * input_sub_mu);
  auto gO_sum = sum_exclude_dim1(gO);

  Tensor gI;
  if (ggI.defined() && training) {
    auto ggI_sum = sum_exclude_dim1(ggI);
    auto ggIinmu_sum = sum_exclude_dim1(ggI * input_sub_mu);
    auto all_sub = ((ggI_sum * gO_sum).div_(M)).sub_(sum_exclude_dim1(gO * ggI)).add_(
                    (sigma2_eps_neg_1 * gOinmu_sum * ggIinmu_sum).mul_(3. / M));
    auto gI_0t = (input_mu_sigma2_neg_3_2 * all_sub).div_(M);
    auto gI_1t = (ggIinmu_sum * sigma2_eps_neg_3_2).div_(M) * (gO_sum.div(M) - gO);
    auto gI_2t = (gOinmu_sum * sigma2_eps_neg_3_2).div_(M) * (ggI_sum.div(M) - ggI);
    gI = gamma_expanded * (gI_0t.add_(gI_1t).add_(gI_2t));
  }

  // add contribution of gamma term to gI
  Tensor gI_G_term;
  if (affine && ggG.defined()) {
    if (training) {
      auto t0 = gO * sigma2_eps_neg_1_2;
      auto t1 = (sigma2_eps_neg_1_2 * gO_sum).div_(-M);
      // Same quantity as sum_exclude_dim1(gO * input_sub_mu); reuse it.
      auto t2 = (input_mu_sigma2_neg_3_2 * gOinmu_sum).div_(-M);
      gI_G_term = ggG_expanded * (t0.add_(t1).add_(t2));
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    } else {
      gI_G_term = ggG_expanded * sigma2_eps_neg_1_2 * gO;
      gI = gI.defined() ? gI.add_(gI_G_term) : gI_G_term;
    }
  }

  // this is the first backward's grad_input
  auto first_back_grad_input = [&](const Tensor& gO, const Tensor& gamma) -> Tensor {
    auto h0 = (gamma * sigma2_eps_neg_1_2).div_(M);
    auto h1 = (M * gO).sub_(sum_exclude_dim1(gO)).sub_(
                input_sub_mu.mul(sigma2_eps_neg_1) * sum_exclude_dim1(gO * input_sub_mu));
    return h0 * h1;
  };

  // calculate gG
  Tensor gG;
  if (affine && ggI.defined()) {
    if (training) {
      // gG is just the first backwards with the gamma term removed (then shaped properly)
      gG = ggI * first_back_grad_input(gO, at::ones({}, sigma2_eps_neg_1_2.options()));
      gG = sum_exclude_dim1(gG, false);
    } else {
      gG = sum_exclude_dim1(ggI * gO * sigma2_eps_neg_1_2, false);
    }
  }

  // calculate ggO
  Tensor ggO;
  // contribution of input term
  if (ggI.defined()) {
    if (training) {
      ggO = first_back_grad_input(ggI, gamma_expanded);
    } else {
      ggO = ggI * sigma2_eps_neg_1_2 * gamma_expanded;
    }
  }
  if (ggG.defined()) {
    auto ggO_G_term = ggG_expanded * input_sub_mu * sigma2_eps_neg_1_2;
    ggO = ggO.defined() ? ggO.add_(ggO_G_term) : ggO_G_term;
  }
  if (ggB.defined()) {
    auto ggO_B_term = ggB_expanded;
    ggO = ggO.defined() ? ggO.add_(ggO_B_term) : ggO_B_term;
  }

  // output_mask is indexed like the returned tuple: 0 -> gI, 1 -> gG, 2 -> ggO.
  if (output_mask[0] && !gI.defined()) gI = at::zeros_like(input);
  if (output_mask[1] && !gG.defined()) {
    AT_ASSERTM(affine, "gamma should always be defined when it requires grad");
    gG = at::zeros_like(gamma);
  }
  if (output_mask[2] && !ggO.defined()) ggO = at::zeros_like(gO);

  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
std::tuple<Tensor, Tensor, Tensor> _trilinear_backward(const Tensor& grad_out, const Tensor& i1, const Tensor& i2, const Tensor& i3,
IntList expand1, IntList expand2, IntList expand3,
IntList sumdim, int64_t unroll_dim, std::array<bool, 3> grad_mask) {
Tensor grad_i1, grad_i2, grad_i3;
if (grad_mask[0])
grad_i1 = at::_trilinear(grad_out, i2, i3, sumdim, expand2, expand3, expand1);
if (grad_mask[1])
grad_i2 = at::_trilinear(i1, grad_out, i3, expand1, sumdim, expand3, expand2);
if (grad_mask[2])
grad_i3 = at::_trilinear(i1, i2, grad_out, expand1, expand2, sumdim, expand3);
return std::tuple<Tensor, Tensor, Tensor>(grad_i1, grad_i2, grad_i3);
}
// Backward of log1p: d/dx log(1 + x) = 1 / (1 + x), scaled by `grad`.
// Sparse inputs are rejected because the resulting gradient would be dense.
Tensor log1p_backward(const Tensor& grad, const Tensor& self) {
  if (self.is_sparse()) {
    AT_ERROR(
      "log1p of a sparse tensor is made to be non-differentiable since ",
      "local gradient of zero is 1 / (0 + 1) = 1 and it makes the tensor dense. ",
      "Use a different mathematical operation which preserves sparsity of gradients, ",
      "or report a bug if you think this is an error.");
  }
  const auto one_plus_self = self + 1;
  return grad / one_plus_self;
}
// Backward of the sparse COO constructor w.r.t. its `values` argument.
// The gradient of each stored value is read off the (densified) output
// gradient at that value's index: the leading indices.size(0) (sparse)
// dimensions of the dense gradient are flattened into one, and the rows at
// the flattened `indices` are gathered.
//
// sparse_grad_out: gradient of the constructed tensor (sparse or dense).
// indices:         the constructor's index tensor; indices.size(0) is the
//                  number of sparse dimensions.
// values_shape:    shape of the constructor's values; dims 1.. give the
//                  trailing (dense) shape of the result.
Tensor sparse_constructor_values_backward(const Tensor& sparse_grad_out, const Tensor& indices, IntList values_shape) {
  // TODO: improve this backward by writing a kernel (maybe)
  auto dense_grad = sparse_grad_out.is_sparse() ? sparse_grad_out.to_dense() : sparse_grad_out;
  auto full_size = sparse_grad_out.sizes();
  auto flattened_grad_shape = values_shape.vec();
  flattened_grad_shape[0] = at::prod_intlist(full_size.slice(0, indices.size(0)));
  // reshape rather than view: a dense incoming gradient is not guaranteed to
  // be contiguous (e.g. it may be expanded), and view throws when the sparse
  // dims cannot be collapsed without a copy. reshape is a view whenever one
  // is possible and copies otherwise.
  auto flattened_dense_grad = dense_grad.reshape(flattened_grad_shape);
  auto flattened_indices = at::sparse::flatten_indices(indices, full_size);
  return flattened_dense_grad.index_select(0, flattened_indices);
}
// Backward of constant_pad_nd: the backward of pad(input, pads) is just
// pad(grad_output, [-p for p in pads]), i.e. the gradient is padded by the
// negated amounts with fill value 0.
Tensor constant_pad_nd_backward(const Tensor& grad, IntList pad) {
  std::vector<int64_t> undo_pad;
  undo_pad.reserve(pad.size());
  for (const int64_t amount : pad) {
    undo_pad.push_back(-amount);
  }
  return at::constant_pad_nd(grad, undo_pad, 0);
}
} // anonymous namespace
${autograd_function_definitions}
}}} // namespace torch::autograd::generated