Prefer mT and mH over transpose(-2, -1) and transpose(-2, -1).conj() (#64181)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64181

This PR replaces all the calls to:
- `transpose(-2, -1)` or `transpose(-1, -2)` by `mT()` in C++ and `mT` in Python
- `conj().transpose(-2, -1)` or `transpose(-2, -1).conj()` or `conj().transpose(-1, -2)` or `transpose(-1, -2).conj()` by `mH()` in C++ and `mH` in Python.

It also simplifies two pieces of code, and fixes one bug where a pair
of parentheses were missing in the function `make_symmetric_matrices`.

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D31692896

Pulled By: anjali411

fbshipit-source-id: e9112c42343663d442dc5bd53ff2b492094b434a
This commit is contained in:
lezcano
2021-10-18 13:00:48 -07:00
committed by Facebook GitHub Bot
parent 44fd312604
commit 0974215c4d
32 changed files with 302 additions and 305 deletions

View File

@ -157,7 +157,7 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
assert B_t.shape[-1] == q, (B_t.shape, q)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
V = Vh.conj().transpose(-2, -1)
V = Vh.mH
V = Q.matmul(V)
else:
Q = get_approximate_basis(A, q, niter=niter, M=M)
@ -171,7 +171,7 @@ def _svd_lowrank(A: Tensor, q: Optional[int] = 6, niter: Optional[int] = 2,
assert B_t.shape[-1] == n, (B_t.shape, n)
assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
V = Vh.conj().transpose(-2, -1)
V = Vh.mH
U = Q.matmul(U)
return U, S, V