mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-11 22:34:53 +08:00
Prefer mT and mH over transpose(-2, -1) and transpose(-2, -1).conj() (#64181)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64181 This PR replaces all the calls to: - `transpose(-2, -1)` or `transpose(-1, -2)` by `mT()` in C++ and `mT` in Python - `conj().transpose(-2, -1)` or `transpose(-2, -1).conj()` or `conj().transpose(-1, -2)` or `transpose(-1, -2).conj()` by `mH()` in C++ and `mH` in Python. It also simplifies two pieces of code, and fixes one bug where a pair of parentheses were missing in the function `make_symmetric_matrices`. Test Plan: Imported from OSS Reviewed By: H-Huang Differential Revision: D31692896 Pulled By: anjali411 fbshipit-source-id: e9112c42343663d442dc5bd53ff2b492094b434a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
44fd312604
commit
0974215c4d
@ -157,7 +157,7 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
|
||||
assert B_t.shape[-1] == q, (B_t.shape, q)
|
||||
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
|
||||
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
|
||||
V = Vh.conj().transpose(-2, -1)
|
||||
V = Vh.mH
|
||||
V = Q.matmul(V)
|
||||
else:
|
||||
Q = get_approximate_basis(A, q, niter=niter, M=M)
|
||||
@ -171,7 +171,7 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
|
||||
assert B_t.shape[-1] == n, (B_t.shape, n)
|
||||
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
|
||||
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
|
||||
V = Vh.conj().transpose(-2, -1)
|
||||
V = Vh.mH
|
||||
U = Q.matmul(U)
|
||||
|
||||
return U, S, V
|
||||
|
||||
Reference in New Issue
Block a user