mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow large inputs to svd_lowrank. Fix inaccuracy in torch.svd docs. (#47440)
Summary: As in title. Fixes https://github.com/pytorch/pytorch/issues/42062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/47440 Reviewed By: bdhirsh Differential Revision: D24790628 Pulled By: mruberry fbshipit-source-id: 1442eb884fbe4ffe6d9c78a4d0186dd0b1482c9c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
52fe73a39e
commit
c8a42c32a1
@ -143,15 +143,19 @@ def _svd_lowrank(A, q=6, niter=2, M=None):
|
||||
|
||||
# Algorithm 5.1 in Halko et al 2009, slightly modified to reduce
|
||||
# the number conjugate and transpose operations
|
||||
if m < n:
|
||||
# computing the SVD approximation of a transpose in order to
|
||||
# keep B shape minimal
|
||||
if m < n or n > q:
|
||||
# computing the SVD approximation of a transpose in
|
||||
# order to keep B shape minimal (the m < n case) or the V
|
||||
# shape small (the n > q case)
|
||||
Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
|
||||
Q_c = _utils.conjugate(Q)
|
||||
if M is None:
|
||||
B_t = matmul(A, Q_c)
|
||||
else:
|
||||
B_t = matmul(A, Q_c) - matmul(M, Q_c)
|
||||
assert B_t.shape[-2] == m, (B_t.shape, m)
|
||||
assert B_t.shape[-1] == q, (B_t.shape, q)
|
||||
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
|
||||
U, S, V = torch.svd(B_t)
|
||||
V = Q.matmul(V)
|
||||
else:
|
||||
@ -161,7 +165,11 @@ def _svd_lowrank(A, q=6, niter=2, M=None):
|
||||
B = matmul(A_t, Q_c)
|
||||
else:
|
||||
B = matmul(A_t, Q_c) - matmul(M_t, Q_c)
|
||||
U, S, V = torch.svd(_utils.transpose(B))
|
||||
B_t = _utils.transpose(B)
|
||||
assert B_t.shape[-2] == q, (B_t.shape, q)
|
||||
assert B_t.shape[-1] == n, (B_t.shape, n)
|
||||
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
|
||||
U, S, V = torch.svd(B_t)
|
||||
U = Q.matmul(U)
|
||||
|
||||
return U, S, V
|
||||
|
Reference in New Issue
Block a user