Modify Cholesky derivative (#19116)

Summary:
The gradient returned by the backward of the Cholesky decomposition was previously a triangular matrix, even though the input to torch.cholesky is symmetric.

Changelog:
- Modify the derivative of Cholesky from a triangular matrix to a symmetric matrix
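
For illustration, a minimal sketch of the new behaviour (an added example, not part of the commit; it assumes a PyTorch build that includes this change):

    import torch

    a = torch.randn(3, 3)
    x = (a @ a.t() + 3 * torch.eye(3)).requires_grad_()   # symmetric positive-definite input
    torch.cholesky(x, upper=False).sum().backward()
    # The gradient w.r.t. the symmetric input used to be lower-triangular;
    # with this change it is symmetric.
    print(torch.allclose(x.grad, x.grad.transpose(-1, -2)))  # True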
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19116

Differential Revision: D14935470

Pulled By: ezyang

fbshipit-source-id: 1c1c76b478c6b99e4e16624682842cb632e8e8b9
Author:       vishwakftw
Date:         2019-04-15 11:53:44 -07:00
Committed by: Facebook Github Bot
Parent:       991279dc7d
Commit:       3403cb857b

3 changed files with 46 additions and 41 deletions

@@ -17,7 +17,7 @@ from torch.autograd.profiler import profile
 from torch.utils.checkpoint import checkpoint
 from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
                           suppress_warnings, skipIfRocm,
-                          load_tests)
+                          load_tests, random_symmetric_pd_matrix)
 from common_cuda import TEST_CUDA
 from torch.autograd import Variable, Function, detect_anomaly
 from torch.autograd.function import InplaceFunction
@@ -2162,19 +2162,19 @@ class TestAutograd(TestCase):
     @skipIfNoLapack
     def test_cholesky(self):
-        def func(root):
+        def func(root, upper):
             x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
             return torch.cholesky(x, upper)
 
         def run_test(upper, dims):
-            root = torch.rand(*dims)
-            indices = torch.ones(dims[-1], dims[-1], dtype=torch.uint8).tril()
-            indices = indices.expand_as(root)
-            root[indices] = 0
-            root.requires_grad_()
+            root = torch.rand(*dims, requires_grad=True)
 
-            gradcheck(func, [root])
-            gradgradcheck(func, [root])
+            gradcheck(func, [root, upper])
+            gradgradcheck(func, [root, upper])
+
+            root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_()
+            chol = root.cholesky().sum().backward()
+            self.assertEqual(root.grad, root.grad.transpose(-1, -2))  # Check the gradient is symmetric
 
         for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
             run_test(upper, dims)
 

@@ -1726,10 +1726,8 @@ class TestDistributions(TestCase):
         # construct batch of PSD covariances
         tmp = torch.randn(6, 5, 3, 10)
         cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
-        prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
-        prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
-        scale_tril_batched = [torch.cholesky(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
-        scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)
+        prec_batched = cov_batched.inverse()
+        scale_tril_batched = cov_batched.cholesky(upper=False)
 
         # ensure that sample, batch, event shapes all handled correctly
         self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
@@ -1750,13 +1748,26 @@ class TestDistributions(TestCase):
         self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))
 
         # check gradients
-        self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov_batched))
-        self._gradcheck_log_prob(MultivariateNormal, (mean, None, prec))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, prec_batched))
-        self._gradcheck_log_prob(MultivariateNormal, (mean, None, None, scale_tril))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
+        # We write a custom gradcheck function to maintain the symmetry
+        # of the perturbed covariances and their inverses (precision)
+        def multivariate_normal_log_prob_gradcheck(mean, covariance=None, precision=None, scale_tril=None):
+            mvn_samples = MultivariateNormal(mean, covariance, precision, scale_tril).sample().requires_grad_()
+
+            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
+                if sigma is not None:
+                    sigma = 0.5 * (sigma + sigma.transpose(-1, -2))  # Ensure symmetry of covariance
+                if prec is not None:
+                    prec = 0.5 * (prec + prec.transpose(-1, -2))  # Ensure symmetry of precision
+                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
+            gradcheck(gradcheck_func, (mvn_samples, mean, covariance, precision, scale_tril), raise_exception=True)
+
+        multivariate_normal_log_prob_gradcheck(mean, cov)
+        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
+        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
+        multivariate_normal_log_prob_gradcheck(mean, None, prec)
+        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
+        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
+        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, None, scale_tril_batched)
 
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_multivariate_normal_log_prob(self):

@@ -694,31 +694,25 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra
 Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
   // cf. Iain Murray (2016); arXiv 1602.07527
+  // This gradient is symmetric, and not triangular.
+  // Cholesky additionally assumes that the input is symmetric, which is a subspace of
+  // R^{n x n}, and hence the derivative is not well-defined for off-diagonal
+  // elements. We resolve this by taking the gradient of the functionally independent
+  // elements of the matrix (i.e., the lower triangular portion of the input) and then
+  // reflect it on the upper triangular portion, thereby symmetrizing the gradient of
+  // the cholesky operation. The motivation behind this choice is that symmetric gradient
+  // leads to stable gradient updates, and retains symmetry of the updated matrix if it
+  // were updated by a gradient based algorithm.
   if (upper) {
-    grad = grad.transpose(-1, -2);
-  } else {
     L = L.transpose(-1, -2);
+    grad = grad.transpose(-1, -2);
   }
-  auto phi = [](const Tensor & A) -> Tensor {
-    auto B = A.tril();
-    B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1);
-    return B;
-  };
-  // make sure not to double-count variation, since
-  // only half of output matrix is unique
-  auto Lbar = grad.tril();
-  auto P = phi(at::matmul(L, Lbar));
-  Tensor S;
-  std::tie(S, std::ignore) = at::solve(P + P.transpose(-1, -2), L);
-  std::tie(S, std::ignore) = at::solve(S.transpose(-1, -2), L);
-  S = phi(S);
-  if (upper) {
-    S = S.transpose(-1, -2);
-  }
-  return S;
+  auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false));
+  auto phi = at::matmul(L.transpose(-1, -2), grad);
+  phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);
+
+  auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2), phi), L_inverse);
+  return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient
 }
 
 Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
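
For reference, a sketch of the formula that the new cholesky_backward computes (following Murray, arXiv 1602.07527; the operator Phi below is notation introduced here for the "keep the lower triangle and halve the diagonal" step, and A = L L^T denotes the symmetric input):

    \Phi(X) = \mathrm{tril}(X) - \tfrac{1}{2}\,\mathrm{diag}\bigl(\mathrm{diag}(X)\bigr)
    \bar{A} = \tfrac{1}{2}\Bigl( L^{-\top}\,\Phi(L^{\top}\bar{L})\,L^{-1}
              + \bigl(L^{-\top}\,\Phi(L^{\top}\bar{L})\,L^{-1}\bigr)^{\top} \Bigr)

Here \bar{L} is the incoming gradient with respect to L and \bar{A} is the symmetrized gradient returned with respect to A, corresponding to grad_input in the code above.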