Allow large inputs to svd_lowrank. Fix inaccuracy in torch.svd docs. (#47440)

Summary:
As in title.

Fixes https://github.com/pytorch/pytorch/issues/42062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47440

Reviewed By: bdhirsh

Differential Revision: D24790628

Pulled By: mruberry

fbshipit-source-id: 1442eb884fbe4ffe6d9c78a4d0186dd0b1482c9c
This commit is contained in:
Pearu Peterson
2020-11-09 21:02:57 -08:00
committed by Facebook GitHub Bot
parent 52fe73a39e
commit c8a42c32a1
2 changed files with 17 additions and 7 deletions

View File

@ -143,15 +143,19 @@ def _svd_lowrank(A, q=6, niter=2, M=None):
# Algorithm 5.1 in Halko et al 2009, slightly modified to reduce
# the number conjugate and transpose operations
if m < n:
# computing the SVD approximation of a transpose in order to
# keep B shape minimal
if m < n or n > q:
# computing the SVD approximation of a transpose in
# order to keep B shape minimal (the m < n case) or the V
# shape small (the n > q case)
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
Q_c = _utils.conjugate(Q)
if M is None:
B_t = matmul(A, Q_c)
else:
B_t = matmul(A, Q_c) - matmul(M, Q_c)
assert B_t.shape[-2] == m, (B_t.shape, m)
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, V = torch.svd(B_t)
V = Q.matmul(V)
else:
@ -161,7 +165,11 @@ def _svd_lowrank(A, q=6, niter=2, M=None):
B = matmul(A_t, Q_c)
else:
B = matmul(A_t, Q_c) - matmul(M_t, Q_c)
U, S, V = torch.svd(_utils.transpose(B))
B_t = _utils.transpose(B)
assert B_t.shape[-2] == q, (B_t.shape, q)
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, V = torch.svd(B_t)
U = Q.matmul(U)
return U, S, V