mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: According to pytorch/rfcs#3 From the goals in the RFC: 1. Support subclassing `torch.Tensor` in Python (done here) 2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them (done here) 3. Use the PyTorch API with `torch.Tensor`-like objects that are _not_ `torch.Tensor` subclasses (done in https://github.com/pytorch/pytorch/issues/30730) 4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods. (done here) 5. Propagating subclass instances correctly also with operators, using views/slices/indexing/etc. (done here) 6. Preserve subclass attributes when using methods or views/slices/indexing. (done here) 7. A way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators). (done here) 8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol. (will be addressed in a separate PR) This PR makes the following changes: 1. Adds the `self` argument to the arg parser. 2. Dispatches on `self` as well if `self` is not `nullptr`. 3. Adds a `torch._C.DisableTorchFunction` context manager to disable `__torch_function__`. 4. Adds a `torch::torch_function_enabled()` and `torch._C._torch_function_enabled()` to check the state of `__torch_function__`. 5. Dispatches all `torch._C.TensorBase` and `torch.Tensor` methods via `__torch_function__`. TODO: - [x] Sequence Methods - [x] Docs - [x] Tests Closes https://github.com/pytorch/pytorch/issues/28361 Benchmarks in https://github.com/pytorch/pytorch/pull/37091#issuecomment-633657778 Pull Request resolved: https://github.com/pytorch/pytorch/pull/37091 Reviewed By: ngimel Differential Revision: D22765678 Pulled By: ezyang fbshipit-source-id: 53f8aa17ddb8b1108c0997f6a7aa13cb5be73de0
743 lines
28 KiB
Python
743 lines
28 KiB
Python
"""Locally Optimal Block Preconditioned Conjugate Gradient methods.
|
|
"""
|
|
# Author: Pearu Peterson
|
|
# Created: February 2020
|
|
|
|
from typing import Dict, Tuple, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from . import _linalg_utils as _utils
|
|
from .overrides import has_torch_function, handle_torch_function
|
|
|
|
|
|
# Public API of this module: only the `lobpcg` front-end is exported.
__all__ = ['lobpcg']
|
|
|
|
|
|
def lobpcg(A,                   # type: Tensor
           k=None,              # type: Optional[int]
           B=None,              # type: Optional[Tensor]
           X=None,              # type: Optional[Tensor]
           n=None,              # type: Optional[int]
           iK=None,             # type: Optional[Tensor]
           niter=None,          # type: Optional[int]
           tol=None,            # type: Optional[float]
           largest=None,        # type: Optional[bool]
           method=None,         # type: Optional[str]
           tracker=None,        # type: Optional[None]
           ortho_iparams=None,  # type: Optional[Dict[str, int]]
           ortho_fparams=None,  # type: Optional[Dict[str, float]]
           ortho_bparams=None,  # type: Optional[Dict[str, bool]]
           ):
    # type: (...) -> Tuple[Tensor, Tensor]

    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends least time per
      iteration. However, the robust methods converge much faster and
      are more stable. So, the usage of the basic method is generally
      not recommended but there exist cases where the usage of the
      basic method may be preferred.

    Arguments:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as
                  identity matrix.

      X (Tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  initial approximation of eigenvectors. X must be a
                  dense tensor.

      iK (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, the value of `n`
                  (when specified) must be the number of :math:`X`
                  columns. The method requires `m >= 3 * n`.

      tol (float, optional): residual tolerance for stopping
                 criterion. Default is `feps ** 0.5` where `feps` is
                 the machine epsilon (spacing between 1.0 and the next
                 representable number) of the floating-point data type
                 used for the input tensor `A`.

      largest (bool, optional): when True, solve the eigenproblem for
                 the largest eigenvalues. Otherwise, solve the
                 eigenproblem for smallest eigenvalues. Default is
                 `True`.

      method (str, optional): select LOBPCG method. See the
                 description of the function above. Default is
                 "ortho".

      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 Default is `1000`. To iterate without a limit, i.e.
                 until the convergence criterion is met, use `-1`.

      tracker (callable, optional) : a function for tracing the
                 iteration process. When specified, it is called at
                 each iteration step with LOBPCG instance as an
                 argument. The LOBPCG instance holds the full state of
                 the iteration process in the following attributes:

                   `iparams`, `fparams`, `bparams` - dictionaries of
                   integer, float, and boolean valued input
                   parameters, respectively

                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                   of integer, float, boolean, and Tensor valued
                   iteration variables, respectively.

                   `A`, `B`, `iK` - input Tensor arguments.

                   `E`, `X`, `S`, `R` - iteration Tensor variables.

                 For instance:

                   `ivars["istep"]` - the current iteration step
                   `X` - the current approximation of eigenvectors
                   `E` - the current approximation of eigenvalues
                   `R` - the current residual
                   `ivars["converged_count"]` - the current number of converged eigenpairs
                   `tvars["rerr"]` - the current state of convergence criteria

                 Note that when `tracker` stores Tensor objects from
                 the LOBPCG instance, it must make copies of these.

                 If `tracker` sets `bvars["force_stop"] = True`, the
                 iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                 various parameters to LOBPCG algorithm when using
                 `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    Raises:

      ValueError: when `m < 3 * n`.

    References:

    [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
    Preconditioned Eigensolver: Locally Optimal Block Preconditioned
    Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
    517-541. (25 pages)
    https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

    [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
    Wu. (2002) A Block Orthogonalization Procedure with Constant
    Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
    2165-2182. (18 pages)
    https://epubs.siam.org/doi/10.1137/S1064827500370883

    [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
    Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
    SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
    https://epubs.siam.org/doi/abs/10.1137/17M1129830

    """

    # __torch_function__ dispatch for Tensor-like arguments (eager mode only).
    if not torch.jit.is_scripting():
        tensor_ops = (A, B, X, iK)
        if (not set(map(type, tensor_ops)).issubset((torch.Tensor, type(None))) and has_torch_function(tensor_ops)):
            return handle_torch_function(
                lobpcg, tensor_ops, A, k=k,
                B=B, X=X, n=n, iK=iK, niter=niter, tol=tol,
                largest=largest, method=method, tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams)

    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        # machine epsilon of float32 / float64
        feps = {torch.float32: 1.2e-07,
                torch.float64: 2.23e-16}[dtype]
        tol = feps ** 0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    # when X is given, its column count always determines n
    n = (k if n is None else n) if X is None else X.shape[-1]

    if (m < 3 * n):
        raise ValueError(
            'LOBPCG algorithm is not applicable when the number of A rows (={})'
            ' is smaller than 3 x the number of requested eigenpairs (={})'
            .format(m, n))

    method = 'ortho' if method is None else method

    iparams = {
        'm': m,
        'n': n,
        'k': k,
        'niter': 1000 if niter is None else niter,
    }

    fparams = {
        'tol': tol,
    }

    bparams = {
        'largest': True if largest is None else largest
    }

    if method == 'ortho':
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams['ortho_i_max'] = iparams.get('ortho_i_max', 3)
        iparams['ortho_j_max'] = iparams.get('ortho_j_max', 3)
        fparams['ortho_tol'] = fparams.get('ortho_tol', tol)
        fparams['ortho_tol_drop'] = fparams.get('ortho_tol_drop', tol)
        fparams['ortho_tol_replace'] = fparams.get('ortho_tol_replace', tol)
        bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False)

    # Install the Python tracker hook on the LOBPCG class for this call.
    # NOTE(review): the hook is restored below only on normal return; if
    # worker.run() raises, LOBPCG.call_tracker stays patched. A try/finally
    # is presumably avoided because TorchScript does not support it.
    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker

    if len(A.shape) > 2:
        # Batched input: flatten the leading dimensions and solve each
        # problem independently.
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams['batch_index'] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig

    return worker.E[:k], worker.X[:, :k]
|
|
|
|
|
|
class LOBPCG(object):
    """Worker class of LOBPCG methods.

    Holds the input operators (`A`, `B`, `iK`), the parameter
    dictionaries (`iparams`, `fparams`, `bparams`) and the iteration
    state (`X`, `E`, `R`, `S` plus the `ivars`, `fvars`, `bvars` and
    `tvars` dictionaries). `run()` drives the iteration; when
    `method == 'ortho'` each step uses `_update_ortho`, otherwise
    `_update_basic`.
    """

    def __init__(self,
                 A,        # type: Optional[Tensor]
                 B,        # type: Optional[Tensor]
                 X,        # type: Tensor
                 iK,       # type: Optional[Tensor]
                 iparams,  # type: Dict[str, int]
                 fparams,  # type: Dict[str, float]
                 bparams,  # type: Dict[str, bool]
                 method,   # type: str
                 tracker   # type: Optional[None]
                 ):
        # type: (...) -> None

        # constant parameters
        self.A = A
        self.B = B
        self.iK = iK
        self.iparams = iparams
        self.fparams = fparams
        self.bparams = bparams
        self.method = method
        self.tracker = tracker
        m = iparams['m']
        n = iparams['n']

        # variable parameters
        self.X = X
        self.E = torch.zeros((n, ), dtype=X.dtype, device=X.device)
        self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
        # search subspace basis: room for [X, P, W] blocks of n columns each
        self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
        self.tvars = {}  # type: Dict[str, Tensor]
        self.ivars = {'istep': 0}  # type: Dict[str, int]
        self.fvars = {'_': 0.0}  # type: Dict[str, float]
        self.bvars = {'_': False}  # type: Dict[str, bool]

    def __str__(self):
        lines = ['LOBPCG:']
        lines += ['  iparams={}'.format(self.iparams)]
        lines += ['  fparams={}'.format(self.fparams)]
        lines += ['  bparams={}'.format(self.bparams)]
        lines += ['  ivars={}'.format(self.ivars)]
        lines += ['  fvars={}'.format(self.fvars)]
        lines += ['  bvars={}'.format(self.bvars)]
        lines += ['  tvars={}'.format(self.tvars)]
        lines += ['  A={}'.format(self.A)]
        lines += ['  B={}'.format(self.B)]
        lines += ['  iK={}'.format(self.iK)]
        lines += ['  X={}'.format(self.X)]
        lines += ['  E={}'.format(self.E)]
        r = ''
        for line in lines:
            r += line + '\n'
        return r

    def update(self):
        """Set and update iteration variables.

        On the first step (`istep == 0`) this records the norms used by
        the convergence criterion and initializes the iteration counters.
        Every call then performs one method-specific update step and
        advances `istep` / decrements `iterations_left`.
        """
        if self.ivars['istep'] == 0:
            X_norm = float(torch.norm(self.X))
            iX_norm = X_norm ** -1
            A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
            B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
            self.fvars['X_norm'] = X_norm
            self.fvars['A_norm'] = A_norm
            self.fvars['B_norm'] = B_norm
            self.ivars['iterations_left'] = self.iparams['niter']
            self.ivars['converged_count'] = 0
            self.ivars['converged_end'] = 0

        if self.method == 'ortho':
            self._update_ortho()
        else:
            self._update_basic()

        self.ivars['iterations_left'] = self.ivars['iterations_left'] - 1
        self.ivars['istep'] = self.ivars['istep'] + 1

    def update_residual(self):
        """Update residual R from A, B, X, E.

        Computes `R = A X - (B X) * E`, where `E` scales the columns.
        """
        mm = _utils.matmul
        self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E

    def update_converged_count(self):
        """Determine the number of converged eigenpairs using backward stable
        convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].

        Only a leading run of converged pairs is counted, so that the
        converged eigenpairs stay strictly ordered. Stores the count in
        `ivars['converged_count']` and the per-column relative residuals
        in `tvars['rerr']`, and returns the count.

        Users may redefine this method for custom convergence criteria.
        """
        # (...) -> int
        prev_count = self.ivars['converged_count']
        tol = self.fparams['tol']
        A_norm = self.fvars['A_norm']
        B_norm = self.fvars['B_norm']
        E, X, R = self.E, self.X, self.R
        rerr = torch.norm(R, 2, (0, )) * (torch.norm(X, 2, (0, )) * (A_norm + E[:X.shape[-1]] * B_norm)) ** -1
        converged = rerr < tol
        count = 0
        for b in converged:
            if not b:
                # ignore convergence of following pairs to ensure
                # strict ordering of eigenpairs
                break
            count += 1
        assert count >= prev_count, 'the number of converged eigenpairs ' \
            '(was {}, got {}) cannot decrease'.format(prev_count, count)
        self.ivars['converged_count'] = count
        self.tvars['rerr'] = rerr
        return count

    def stop_iteration(self):
        """Return True to stop iterations.

        Iterations stop when the tracker forced a stop, when the
        iteration budget is exhausted, or when at least `k` eigenpairs
        have converged.

        Note that tracker (if defined) can force-stop iterations by
        setting ``worker.bvars['force_stop'] = True``.
        """
        return (self.bvars.get('force_stop', False)
                or self.ivars['iterations_left'] == 0
                or self.ivars['converged_count'] >= self.iparams['k'])

    def run(self):
        """Run LOBPCG iterations.

        Use this method as a template for implementing LOBPCG
        iteration scheme with custom tracker that is compatible with
        TorchScript.
        """
        self.update()

        if not torch.jit.is_scripting() and self.tracker is not None:
            self.call_tracker()

        while not self.stop_iteration():

            self.update()

            if not torch.jit.is_scripting() and self.tracker is not None:
                self.call_tracker()

    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.
        """
        # do nothing when in TorchScript mode
        pass

    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.
        """
        mm = torch.matmul
        ns = self.ivars['converged_end']
        nc = self.ivars['converged_count']
        n = self.iparams['n']
        largest = self.bparams['largest']

        if self.ivars['istep'] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars['converged_end'] = ns = n + np + W.shape[-1]
            self.S[:, n + np:ns] = W
        else:
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, :n - nc]))
            self.E[nc:] = E_[:n - nc]
            P = mm(S_, mm(Ri, Z[:, n:2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n:n + np] = P
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars['converged_end'] = ns = n + np + W.shape[-1]
            self.S[:, n + np:ns] = W

    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.
        """
        mm = torch.matmul
        ns = self.ivars['converged_end']
        nc = self.ivars['converged_count']
        n = self.iparams['n']
        largest = self.bparams['largest']

        if self.ivars['istep'] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X = mm(self.X, mm(Ri, Z))
            # NOTE(review): unlike _update_basic, the Ritz values E are not
            # stored into self.E here, so the residual below is computed
            # with the zero-initialized self.E from __init__ (R = A X).
            # Confirm whether this is intended; changing it alters the
            # iteration trajectory.
            self.update_residual()
            np = 0
            nc = self.update_converged_count()
            self.S[:, :n] = self.X

            W = self._get_ortho(self.R, self.X)
            ns = self.ivars['converged_end'] = n + np + W.shape[-1]
            self.S[:, n + np:ns] = W

        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, :n - nc])
            self.E[nc:] = E_[:n - nc]
            P = mm(S_, mm(Z[:, n - nc:], _utils.basis(_utils.transpose(Z[:n - nc, n - nc:]))))
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n:n + np] = P
            W = self._get_ortho(self.R[:, nc:], self.S[:, :n + np])
            ns = self.ivars['converged_end'] = n + np + W.shape[-1]
            self.S[:, n + np:ns] = W

    def _get_rayleigh_ritz_transform(self, S):
        """Return a transformation matrix that is used in Rayleigh-Ritz
        procedure for reducing a general eigenvalue problem :math:`(S^TAS)
        C = (S^TBS) C E` to a standard eigenvalue problem :math:`(Ri^T
        S^TAS Ri) Z = Z E` where `C = Ri Z`.

        .. note:: In the original Rayleigh-Ritz procedure in
          [DuerschEtal2018], the problem is formulated as follows::

            SAS = S^T A S
            SBS = S^T B S
            D = (<diagonal matrix of SBS>) ** -1/2
            R^T R = Cholesky(D SBS D)
            Ri = D R^-1
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          To reduce the number of matrix products (denoted by empty
          space between matrices), here we introduce element-wise
          products (denoted by symbol `*`) so that the Rayleigh-Ritz
          procedure becomes::

            SAS = S^T A S
            SBS = S^T B S
            d = (<diagonal of SBS>) ** -1/2    # this is 1-d column vector
            dd = d d^T                         # this is 2-d matrix
            R^T R = Cholesky(dd * SBS)
            Ri = R^-1 * d                      # broadcasting
            solve symeig problem Ri^T SAS Ri Z = Theta Z
            C = Ri Z

          where `dd` is 2-d matrix that replaces matrix products `D M
          D` with one element-wise product `M * dd`; and `d` replaces
          matrix product `D M` with element-wise product `M *
          d`. Also, creating the diagonal matrix `D` is avoided.

        Arguments:
          S (Tensor): the matrix basis for the search subspace, size is
                      :math:`(m, n)`.

        Returns:
          Ri (Tensor): upper-triangular transformation matrix of size
                       :math:`(n, n)`.

        """
        B = self.B
        SBS = _utils.qform(B, S)
        d_row = SBS.diagonal(0, -2, -1) ** -0.5
        d_col = d_row.reshape(d_row.shape[0], 1)
        R = torch.cholesky((SBS * d_row) * d_col, upper=True)
        # TODO: could use LAPACK ?trtri as R is upper-triangular
        Rinv = torch.inverse(R)
        return Rinv * d_col

    def _get_svqb(self,
                  U,     # Tensor
                  drop,  # bool
                  tau    # float
                  ):
        # type: (Tensor, bool, float) -> Tensor
        """Return B-orthonormal U.

        .. note:: When `drop` is `False` then `svqb` is based on the
                  Algorithm 4 from [DuerschPhD2015] that is a slight
                  modification of the corresponding algorithm
                  introduced in [StathopolousWu2002].

        Arguments:

          U (Tensor) : initial approximation, size is (m, n)
          drop (bool) : when True, drop columns whose
                     contribution to the `span([U])` is small.
          tau (float) : positive tolerance

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
                       is (m, n1), where `n1 = n` if `drop` is `False`,
                       otherwise `n1 <= n`.

        """
        if torch.numel(U) == 0:
            return U
        UBU = _utils.qform(self.B, U)
        d = UBU.diagonal(0, -2, -1)

        # Detect and drop exact zero columns from U. While the test
        # `abs(d) == 0` is unlikely to be True for random data, it is
        # possible to construct input data to lobpcg where it will be
        # True leading to a failure (notice the `d ** -0.5` operation
        # in the original algorithm). To prevent the failure, we drop
        # the exact zero columns here and then continue with the
        # original algorithm below.
        nz = torch.where(abs(d) != 0.0)
        assert len(nz) == 1, nz
        if len(nz[0]) < len(d):
            U = U[:, nz[0]]
            if torch.numel(U) == 0:
                return U
            UBU = _utils.qform(self.B, U)
            d = UBU.diagonal(0, -2, -1)
            nz = torch.where(abs(d) != 0.0)
            assert len(nz[0]) == len(d)

        # The original algorithm 4 from [DuerschPhD2015].
        d_col = (d ** -0.5).reshape(d.shape[0], 1)
        DUBUD = (UBU * d_col) * _utils.transpose(d_col)
        E, Z = _utils.symeig(DUBUD, eigenvectors=True)
        t = tau * abs(E).max()
        if drop:
            keep = torch.where(E > t)
            assert len(keep) == 1, keep
            E = E[keep[0]]
            Z = Z[:, keep[0]]
            d_col = d_col[keep[0]]
        else:
            # clamp small eigenvalues from below so that E ** -0.5 stays bounded
            E[(torch.where(E < t))[0]] = t

        return torch.matmul(U * _utils.transpose(d_col), Z * E ** -0.5)

    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on the Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopolousWu2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015]

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        Arguments:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U=0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.

        Raises:

          ValueError : when the columns of U and V together exceed
                       the problem size `m`.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams['m']
        tau_ortho = self.fparams['ortho_tol']
        tau_drop = self.fparams['ortho_tol_drop']
        tau_replace = self.fparams['ortho_tol_replace']
        i_max = self.iparams['ortho_i_max']
        j_max = self.iparams['ortho_j_max']
        # when use_drop==True, enable dropping U columns that have
        # small contribution to the `span([U, V])`.
        use_drop = self.bparams['ortho_use_drop']

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith('ortho_') and vkey.endswith('_rerr'):
                self.fvars.pop(vkey)
        self.ivars.pop('ortho_i', 0)
        self.ivars.pop('ortho_j', 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(_utils.transpose(V), BU)
        # pre-declared so they are defined after the loops (TorchScript)
        i = j = 0
        for i in range(i_max):
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars['ortho_i'] = i
                    self.ivars['ortho_j'] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(_utils.transpose(U), BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                R = UBU - torch.eye(UBU.shape[-1],
                                    device=UBU.device,
                                    dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = 'ortho_UBUmI_rerr[{}, {}]'.format(i, j)
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    break
            VBU = mm(_utils.transpose(V), BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = 'ortho_VBU_rerr[{}]'.format(i)
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                break
            if m < U.shape[-1] + V.shape[-1]:
                # Report the problem size via `m` (== B.shape[-1] whenever
                # B is given) so that the ValueError is also raised when
                # B is None, instead of failing on an assertion.
                raise ValueError(
                    'Overdetermined shape of U:'
                    ' #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold'
                    .format(m, U.shape[-1], V.shape[-1]))
        self.ivars['ortho_i'] = i
        self.ivars['ortho_j'] = j
        return U
|
|
|
|
|
|
# Calling tracker is separated from LOBPCG definitions because
# TorchScript does not support user-defined callback arguments.
# Keep a reference to the original no-op `LOBPCG.call_tracker` so that
# `lobpcg` can restore it after temporarily installing
# `LOBPCG_call_tracker` as the class's tracker hook.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker
|
|
def LOBPCG_call_tracker(self):
    """Forward the worker instance to its user-supplied `tracker` callable.

    `lobpcg` installs this function as `LOBPCG.call_tracker` while running
    in eager (non-TorchScript) mode.
    """
    user_tracker = self.tracker
    user_tracker(self)
|