[MPS] Fix float32 error on MPS in linalg.matrix_rank and linalg.pinv (#114771)

Fixes #114285. However, a NotImplementedError still remains:

```
NotImplementedError: The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114771
Approved by: https://github.com/lezcano
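For context, here is a minimal sketch (not part of the patch) of how the fixed ops are exercised on an MPS device. The environment variable is the one named in the error above; it is assumed here that it must be set before `import torch` so the still-missing `aten::_linalg_svd.U` can fall back to the CPU.

```python
# Minimal sketch, not part of the patch: exercise the fixed ops on an MPS device.
# PYTORCH_ENABLE_MPS_FALLBACK=1 lets the missing aten::_linalg_svd.U fall back to CPU;
# assumed here that it has to be set before torch is imported.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

if torch.backends.mps.is_available():
    a = torch.randn(5, 3, device="mps")   # float32 input; MPS has no float64
    print(torch.linalg.matrix_rank(a))    # before this fix, hit the float32 error from #114285
    print(torch.linalg.pinv(a).shape)
```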
Committed by: PyTorch MergeBot
Parent: a72190fd51
Commit: d444a3b443

.gitignore (vendored): 1 line changed
@@ -126,6 +126,7 @@ env
 .circleci/scripts/COMMIT_MSG
 scripts/release_notes/*.json
 sccache-stats*.json
+lint.json
 
 # These files get copied over on invoking setup.py
 torchgen/packaged/*
@@ -444,7 +444,12 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
     const optional<Tensor>& atol_opt,
     const optional<Tensor>& rtol_opt,
     const c10::string_view function_name) {
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   auto atol = atol_opt.has_value() ? atol_opt.value() : at::zeros({}, options);
   checkNotComplexTolerance(atol, function_name, "atol");
   Tensor rtol;

@@ -465,7 +470,7 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
     const Tensor& input,
     optional<double> atol_opt,
     optional<double> rtol_opt) {
-  double atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
+  auto atol = atol_opt.has_value() ? atol_opt.value() : 0.0;
   c10::SymFloat rtol;
   if (rtol_opt.has_value()) {
     rtol = rtol_opt.value();

@@ -476,7 +481,12 @@ std::tuple<Tensor, Tensor> get_atol_rtol(
         ? 0.0
         : default_rtol;
   }
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   auto atol_tensor = at::full({}, atol, options);
   auto rtol_tensor = at::full({}, rtol, options);
   return std::make_tuple(atol_tensor, rtol_tensor);

@@ -545,7 +555,12 @@ Tensor linalg_pinv(const Tensor& input, optional<double> atol, optional<double>
 Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
   // For NumPy compatibility the rcond argument is used as relative tolerance
   checkNotComplexTolerance(rcond, "torch.linalg.pinv", "rcond");
-  auto options = input.options().dtype(ScalarType::Double);
+  auto options = input.options();
+  if (input.device().type() == kMetal || input.device().type() == kMPS) {
+    options = options.dtype(ScalarType::Float);
+  } else {
+    options = options.dtype(ScalarType::Double);
+  }
   return at::linalg_pinv(input, at::zeros({}, options), rcond, hermitian);
 }
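The hunks above come from ATen's linear-algebra implementation, where `get_atol_rtol` and `linalg_pinv` are defined: the default atol/rtol tensors were always allocated as Double, which the MPS backend cannot represent, so they are now allocated as Float on Metal/MPS and as Double elsewhere. A small sketch (not part of the patch) of the underlying constraint, assuming only that float64 allocation on MPS raises:

```python
# Sketch of the constraint motivating the change above: MPS tensors cannot be float64,
# so a Double-typed default tolerance tensor cannot live on the same device as the input.
import torch

if torch.backends.mps.is_available():
    torch.zeros((), dtype=torch.float32, device="mps")      # works
    try:
        torch.zeros((), dtype=torch.float64, device="mps")  # expected to raise: no float64 on MPS
    except (TypeError, RuntimeError) as err:
        print(err)
```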
test/test_mps.py: 145 lines changed
@@ -189,6 +189,12 @@ def mps_ops_grad_modifier(ops):
         'msort': [torch.float16],
     }
 
+    ON_MPS_XFAILLIST = {
+        # Failures due to lack of implementation of downstream functions on MPS backend
+        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+        'linalg.matrix_rank': None,
+    }
+
     def addDecorator(op, d) -> None:
         op.decorators = list(op.decorators) if op.decorators is not None else []
         op.decorators.append(d)
@@ -205,6 +211,11 @@ def mps_ops_grad_modifier(ops):
                          unittest.skip,
                          dtypes=SKIPLIST_GRAD[key]))
 
+        if key in ON_MPS_XFAILLIST:
+            addDecorator(op, DecorateInfo(
+                         unittest.expectedFailure,
+                         dtypes=ON_MPS_XFAILLIST[key]))
+
         if key in MACOS_12_3_XFAILLIST_GRAD and (not torch.backends.mps.is_macos13_or_newer()):
             addDecorator(op, DecorateInfo(
                          unittest.expectedFailure,
@@ -722,7 +733,6 @@ def mps_ops_modifier(ops):
         'nn.functional.norm': None,
         'ormqr': None,
         'pca_lowrank': None,
-        'pinverse': None,
         'qr': None,
         'quantile': None,
         'rsub': None,
@@ -792,9 +802,7 @@ def mps_ops_modifier(ops):
         'softmaxwith_dtype': None,
         'float_power': None,
         'full_like': None,
-        'linalg.matrix_rank': None,
         'linalg.matrix_rankhermitian': None,
-        'linalg.pinv': None,
         'linalg.pinvhermitian': None,
         'nonzero_static': None,
 
@@ -918,6 +926,12 @@ def mps_ops_modifier(ops):
         'logit': [torch.float16],
     }
 
+    ON_MPS_XFAILLIST = {
+        # Failures due to lack of implementation of downstream functions on MPS backend
+        # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
+        'linalg.matrix_rank': None,
+    }
+
     EMPTY_OPS_SKIPLIST = {
         # Fill tensors with uninitialized data, causing mismatch with CPU.
         # They occasionally match, thus skipping them.
@@ -954,7 +968,7 @@ def mps_ops_modifier(ops):
                          dtypes=EMPTY_OPS_SKIPLIST[key]))
         if key in SKIPLIST:
             addDecorator(op, DecorateInfo(unittest.skip("Skipped!"), dtypes=SKIPLIST[key]))
-        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST]:
+        for xfaillist in [UNIMPLEMENTED_XFAILLIST, UNDEFINED_XFAILLIST, ON_MPS_XFAILLIST]:
             if key in xfaillist:
                 addDecorator(op, DecorateInfo(
                              unittest.expectedFailure,
@@ -8729,6 +8743,129 @@ class TestLinalgMPS(TestCaseMPS):
         m2 = torch.randn(25, device=device).to(dtype)
         self._test_addr(torch.addr, M, m1, m2, beta=0)
 
+    def test_matrix_rank(self, device="mps", dtype=torch.float32):
+        matrix_rank = torch.linalg.matrix_rank
+
+        def run_test(shape0, shape1, batch):
+            a = torch.randn(*batch, shape0, shape1, dtype=dtype, device=device)
+            rank_a = matrix_rank(a)
+
+            self.assertEqual(rank_a, matrix_rank(a.mH))
+            aaH = torch.matmul(a, a.mH)
+            rank_aaH = matrix_rank(aaH)
+            rank_aaH_hermitian = matrix_rank(aaH, hermitian=True)
+            self.assertEqual(rank_aaH, rank_aaH_hermitian)
+            aHa = torch.matmul(a.mH, a)
+            self.assertEqual(matrix_rank(aHa), matrix_rank(aHa, hermitian=True))
+
+            # check against NumPy
+            self.assertEqual(rank_a, np.linalg.matrix_rank(a.cpu().numpy()))
+            self.assertEqual(matrix_rank(a, 0.01), np.linalg.matrix_rank(a.cpu().numpy(), 0.01))
+
+            self.assertEqual(rank_aaH, np.linalg.matrix_rank(aaH.cpu().numpy()))
+            self.assertEqual(matrix_rank(aaH, 0.01), np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01))
+
+            # hermitian flag for NumPy was added in 1.14.0
+            if np.lib.NumpyVersion(np.__version__) >= '1.14.0':
+                self.assertEqual(rank_aaH_hermitian,
+                                 np.linalg.matrix_rank(aaH.cpu().numpy(), hermitian=True))
+                self.assertEqual(matrix_rank(aaH, 0.01, True),
+                                 np.linalg.matrix_rank(aaH.cpu().numpy(), 0.01, True))
+
+            # check out= variant
+            out = torch.empty(a.shape[:-2], dtype=torch.int64, device=device)
+            ans = matrix_rank(a, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, rank_a)
+
+        shapes = (3, 13)
+        batches = ((), (0, ), (4, ), (3, 5, ))
+        for (shape0, shape1), batch in zip(itertools.product(shapes, reversed(shapes)), batches):
+            # escape only when NotImplementedError of downstream function is raised
+            # TODO: remove this once the required function is implemented
+            try:
+                run_test(shape0, shape1, batch)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_svd.U' is not currently implemented for the MPS device."):
+                    raise e
+
+    def test_pinv(self, device="mps", dtype=torch.float32, precision=1e-4):
+        from torch.testing._internal.common_utils import random_hermitian_pd_matrix
+
+        def run_test_main(A, hermitian):
+            # Testing against definition for pseudo-inverses
+            A_pinv = torch.linalg.pinv(A, hermitian=hermitian)
+            np_A = A.cpu().numpy()
+            np_A_pinv = A_pinv.cpu().numpy()
+            if A.numel() > 0:
+                self.assertEqual(A, np_A @ np_A_pinv @ np_A, atol=precision, rtol=precision)
+                self.assertEqual(A_pinv, np_A_pinv @ np_A @ np_A_pinv, atol=precision, rtol=precision)
+                self.assertEqual(np_A @ np_A_pinv, (np_A @ np_A_pinv).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+                self.assertEqual(np_A_pinv @ np_A, (np_A_pinv @ np_A).conj().swapaxes(-2, -1), atol=precision, rtol=precision)
+            else:
+                self.assertEqual(A.shape, A_pinv.shape[:-2] + (A_pinv.shape[-1], A_pinv.shape[-2]))
+
+            # Check out= variant
+            out = torch.empty_like(A_pinv)
+            ans = torch.linalg.pinv(A, hermitian=hermitian, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, A_pinv)
+
+        def run_test_numpy(A, hermitian):
+            # Check against NumPy output
+            # Test float rcond, and specific value for each matrix
+            rconds = [float(torch.rand(1)), ]
+            # Test different types of rcond tensor
+            for rcond_type in MPS_DTYPES:
+                rconds.append(torch.rand(A.shape[:-2], dtype=torch.float32, device=device).to(rcond_type))
+            # Test broadcasting of rcond
+            if A.ndim > 2:
+                rconds.append(torch.rand(A.shape[-3], device=device))
+            for rcond in rconds:
+                actual = torch.linalg.pinv(A, rcond=rcond, hermitian=hermitian)
+                torch_rtol = torch.linalg.pinv(A, rtol=rcond, hermitian=hermitian)
+                self.assertEqual(actual, torch_rtol, atol=precision, rtol=precision)
+                numpy_rcond = rcond if isinstance(rcond, float) else rcond.cpu().numpy()
+                expected = np.linalg.pinv(A.cpu().numpy(), rcond=numpy_rcond, hermitian=hermitian)
+                self.assertEqual(actual, expected, atol=precision, rtol=precision)
+
+        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
+                      (3, 2), (5, 3, 2), (2, 5, 3, 2),  # fat matrices
+                      (2, 3), (5, 2, 3), (2, 5, 2, 3),  # thin matrices
+                      (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]:  # zero numel matrices
+            A = torch.randn(*sizes, dtype=dtype, device=device)
+            hermitian = False
+            run_test_main(A, hermitian)
+            run_test_numpy(A, hermitian)
+
+        # Check hermitian = True
+        for sizes in [(5, 5), (3, 5, 5), (3, 2, 5, 5),  # square matrices
+                      (0, 0), (3, 0, 0), ]:  # zero numel square matrices
+            A = random_hermitian_pd_matrix(sizes[-1], *sizes[:-2], dtype=dtype, device=device)
+            hermitian = True
+            # escape only when NotImplementedError of downstream function is raised
+            # TODO: remove this once the required function is implemented
+            try:
+                run_test_main(A, hermitian)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+                    raise e
+            try:
+                run_test_numpy(A, hermitian)
+            except NotImplementedError as e:
+                with self.assertRaisesRegex(
+                        NotImplementedError,
+                        "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
+                    raise e
+
+
 class TestGatherScatter(TestCaseMPS):
     def test_slicing_with_step(self):
         # Slicing with step
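A sketch of how the two new tests might be run in isolation. It assumes an Apple-silicon machine with this checkout and that the interpreter is started from the repository's test/ directory so that `test_mps` is importable; setting the fallback variable is optional, since both tests catch the NotImplementedError path shown in the diff above.

```python
# Sketch: run only the new TestLinalgMPS tests. The fallback variable is optional; without it,
# the tests exercise the NotImplementedError escape hatch added in this patch.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import unittest
from test_mps import TestLinalgMPS  # assumes the working directory is the repo's test/ folder

suite = unittest.TestSuite()
suite.addTest(TestLinalgMPS("test_matrix_rank"))
suite.addTest(TestLinalgMPS("test_pinv"))
unittest.TextTestRunner(verbosity=2).run(suite)
```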