Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Modify Cholesky derivative (#19116)
Summary:
The derivative of the Cholesky decomposition was previously returned as a triangular matrix.

Changelog:
- Modify the derivative of Cholesky from a triangular matrix to a symmetric matrix

Pull Request resolved: https://github.com/pytorch/pytorch/pull/19116

Differential Revision: D14935470

Pulled By: ezyang

fbshipit-source-id: 1c1c76b478c6b99e4e16624682842cb632e8e8b9
Committed by: Facebook Github Bot
Parent: 991279dc7d
Commit: 3403cb857b
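To illustrate what this change means for user code, here is a minimal sketch (not part of the patch, and assuming a PyTorch build that includes it): the gradient accumulated into a symmetric positive-definite input of torch.cholesky is now itself symmetric rather than triangular.

import torch

# Build a symmetric positive-definite input A = B B^T + eps * I.
B = torch.randn(4, 4, dtype=torch.double)
A = (B @ B.t() + 1e-3 * torch.eye(4, dtype=torch.double)).requires_grad_()

torch.cholesky(A).sum().backward()

# Before this change A.grad was triangular; with it, A.grad is symmetric.
print(torch.allclose(A.grad, A.grad.t()))  # expected: True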
@@ -17,7 +17,7 @@ from torch.autograd.profiler import profile
 from torch.utils.checkpoint import checkpoint
 from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack,
                           suppress_warnings, skipIfRocm,
-                          load_tests)
+                          load_tests, random_symmetric_pd_matrix)
 from common_cuda import TEST_CUDA
 from torch.autograd import Variable, Function, detect_anomaly
 from torch.autograd.function import InplaceFunction
@@ -2162,19 +2162,19 @@ class TestAutograd(TestCase):

     @skipIfNoLapack
     def test_cholesky(self):
-        def func(root):
+        def func(root, upper):
             x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
             return torch.cholesky(x, upper)

         def run_test(upper, dims):
-            root = torch.rand(*dims)
-            indices = torch.ones(dims[-1], dims[-1], dtype=torch.uint8).tril()
-            indices = indices.expand_as(root)
-            root[indices] = 0
-            root.requires_grad_()
+            root = torch.rand(*dims, requires_grad=True)

-            gradcheck(func, [root])
-            gradgradcheck(func, [root])
+            gradcheck(func, [root, upper])
+            gradgradcheck(func, [root, upper])
+
+            root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_()
+            chol = root.cholesky().sum().backward()
+            self.assertEqual(root.grad, root.grad.transpose(-1, -2))  # Check the gradient is symmetric

         for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
             run_test(upper, dims)
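The new symmetry check relies on the random_symmetric_pd_matrix helper imported from common_utils above; its body is not part of this diff. A rough sketch of what such a helper can look like (only the signature is taken from the test, the internals here are an assumption):

import torch

def random_symmetric_pd_matrix(matrix_size, *batch_dims):
    # A A^T is symmetric positive semi-definite; adding a small multiple of
    # the identity keeps the eigenvalues bounded away from zero.
    A = torch.randn(*batch_dims, matrix_size, matrix_size)
    return A.matmul(A.transpose(-2, -1)) + 1e-5 * torch.eye(matrix_size)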
@@ -1726,10 +1726,8 @@ class TestDistributions(TestCase):
         # construct batch of PSD covariances
         tmp = torch.randn(6, 5, 3, 10)
         cov_batched = (tmp.unsqueeze(-2) * tmp.unsqueeze(-3)).mean(-1).requires_grad_()
-        prec_batched = [C.inverse() for C in cov_batched.view((-1, 3, 3))]
-        prec_batched = torch.stack(prec_batched).view(cov_batched.shape)
-        scale_tril_batched = [torch.cholesky(C, upper=False) for C in cov_batched.view((-1, 3, 3))]
-        scale_tril_batched = torch.stack(scale_tril_batched).view(cov_batched.shape)
+        prec_batched = cov_batched.inverse()
+        scale_tril_batched = cov_batched.cholesky(upper=False)

         # ensure that sample, batch, event shapes all handled correctly
         self.assertEqual(MultivariateNormal(mean, cov).sample().size(), (5, 3))
@@ -1750,13 +1748,26 @@ class TestDistributions(TestCase):
         self.assertEqual(MultivariateNormal(mean, scale_tril=scale_tril_batched).sample((2, 7)).size(), (2, 7, 6, 5, 3))

         # check gradients
-        self._gradcheck_log_prob(MultivariateNormal, (mean, cov))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_multi_batch, cov_batched))
-        self._gradcheck_log_prob(MultivariateNormal, (mean, None, prec))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, prec_batched))
-        self._gradcheck_log_prob(MultivariateNormal, (mean, None, None, scale_tril))
-        self._gradcheck_log_prob(MultivariateNormal, (mean_no_batch, None, None, scale_tril_batched))
+        # We write a custom gradcheck function to maintain the symmetry
+        # of the perturbed covariances and their inverses (precision)
+        def multivariate_normal_log_prob_gradcheck(mean, covariance=None, precision=None, scale_tril=None):
+            mvn_samples = MultivariateNormal(mean, covariance, precision, scale_tril).sample().requires_grad_()
+
+            def gradcheck_func(samples, mu, sigma, prec, scale_tril):
+                if sigma is not None:
+                    sigma = 0.5 * (sigma + sigma.transpose(-1, -2))  # Ensure symmetry of covariance
+                if prec is not None:
+                    prec = 0.5 * (prec + prec.transpose(-1, -2))  # Ensure symmetry of precision
+                return MultivariateNormal(mu, sigma, prec, scale_tril).log_prob(samples)
+            gradcheck(gradcheck_func, (mvn_samples, mean, covariance, precision, scale_tril), raise_exception=True)
+
+        multivariate_normal_log_prob_gradcheck(mean, cov)
+        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov)
+        multivariate_normal_log_prob_gradcheck(mean_multi_batch, cov_batched)
+        multivariate_normal_log_prob_gradcheck(mean, None, prec)
+        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, prec_batched)
+        multivariate_normal_log_prob_gradcheck(mean, None, None, scale_tril)
+        multivariate_normal_log_prob_gradcheck(mean_no_batch, None, None, scale_tril_batched)

     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_multivariate_normal_log_prob(self):
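For context on the comment in the hunk above (this is an illustration, not part of the patch): gradcheck perturbs each entry of the covariance independently, so a perturbed matrix is no longer exactly symmetric; re-symmetrizing it inside the checked function keeps the input to MultivariateNormal (and hence to the Cholesky call inside it) in the symmetric subspace that the new derivative assumes. A small standalone sketch of the same trick, with hypothetical names:

import torch
from torch.autograd import gradcheck

def log_prob_of_fixed_sample(sigma):
    # Re-symmetrize so the entry-wise perturbations made by gradcheck are
    # split evenly between the (i, j) and (j, i) entries.
    sigma = 0.5 * (sigma + sigma.transpose(-1, -2))
    dist = torch.distributions.MultivariateNormal(
        torch.zeros(3, dtype=torch.double), covariance_matrix=sigma)
    return dist.log_prob(torch.ones(3, dtype=torch.double))

base = torch.randn(3, 5, dtype=torch.double)
cov = (base @ base.t() / 5 + 1e-2 * torch.eye(3, dtype=torch.double)).requires_grad_()
print(gradcheck(log_prob_of_fixed_sample, (cov,)))  # expected: True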
@@ -694,31 +694,25 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra

 Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
   // cf. Iain Murray (2016); arXiv 1602.07527
+  // This gradient is symmetric, and not triangular.
+  // Cholesky additionally assumes that the input is symmetric, which is a subspace of
+  // R^{n x n}, and hence the derivative is not well-defined for off-diagonal
+  // elements. We resolve this by taking the gradient of the functionally independent
+  // elements of the matrix (i.e., the lower triangular portion of the input) and then
+  // reflect it on the upper triangular portion, thereby symmetrizing the gradient of
+  // the cholesky operation. The motivation behind this choice is that symmetric gradient
+  // leads to stable gradient updates, and retains symmetry of the updated matrix if it
+  // were updated by a gradient based algorithm.
   if (upper) {
-    grad = grad.transpose(-1, -2);
-  } else {
     L = L.transpose(-1, -2);
+    grad = grad.transpose(-1, -2);
   }
+  auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false));
+  auto phi = at::matmul(L.transpose(-1, -2), grad);
+  phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);

-  auto phi = [](const Tensor & A) -> Tensor {
-    auto B = A.tril();
-    B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1);
-    return B;
-  };
-
-  // make sure not to double-count variation, since
-  // only half of output matrix is unique
-  auto Lbar = grad.tril();
-
-  auto P = phi(at::matmul(L, Lbar));
-  Tensor S;
-  std::tie(S, std::ignore) = at::solve(P + P.transpose(-1, -2), L);
-  std::tie(S, std::ignore) = at::solve(S.transpose(-1, -2), L);
-  S = phi(S);
-  if (upper) {
-    S = S.transpose(-1, -2);
-  }
-  return S;
+  auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2), phi), L_inverse);
+  return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient
 }

 Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
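The replacement body follows the reverse-mode expression cited from Murray (2016): with Phi(X) defined as the lower triangle of X with its diagonal halved, the gradient with respect to the input A = L L^T is 0.5 * (G + G^T), where G = L^{-T} Phi(L^T Lbar) L^{-1} and Lbar is the incoming gradient with respect to L. A rough Python sketch of the same computation for the lower-triangular case (illustrative only, written against the torch Python API rather than ATen), followed by a consistency check that assumes a build containing this patch:

import torch

def cholesky_backward_sketch(grad, L):
    # grad: gradient w.r.t. the lower-triangular factor L of A = L L^T.
    # Returns a symmetric gradient w.r.t. A, mirroring the C++ above:
    #   dA = 0.5 * (G + G^T), G = L^{-T} Phi(L^T grad) L^{-1},
    # where Phi keeps the lower triangle and halves the diagonal.
    n = L.size(-1)
    L_inverse = torch.triangular_solve(torch.eye(n, dtype=L.dtype), L, upper=False)[0]
    phi = (L.transpose(-1, -2) @ grad).tril()
    phi = phi - 0.5 * torch.diag_embed(phi.diagonal(dim1=-2, dim2=-1))
    grad_input = L_inverse.transpose(-1, -2) @ phi @ L_inverse
    return 0.5 * (grad_input + grad_input.transpose(-1, -2))

# Consistency check against autograd (assumes torch includes this patch).
B = torch.randn(3, 3, dtype=torch.double)
A = (B @ B.t() + torch.eye(3, dtype=torch.double)).requires_grad_()
L = torch.cholesky(A)
L.backward(torch.ones_like(L))
print(torch.allclose(A.grad, cholesky_backward_sketch(torch.ones_like(L), L.detach())))  # expected: True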