[MPS] Fix addmm (#116547)

Remove weird logic for designating matrices as transposed if sizes match(which always true if square matrices are multiplied with each other), which resulted in `torch.addmm` returns transposed matrix compared to `torch.mm`, see below:
```
% python -c "import torch;torch.set_default_device('mps');a=torch.eye(2);b=torch.arange(4.0).reshape(2, 2);print(a@b);print(torch.addmm(torch.zeros(2, 2), a,b))"
tensor([[0., 1.],
        [2., 3.]], device='mps:0')
tensor([[0., 2.],
        [1., 3.]], device='mps:0')
```

Fixes introduced to `torch.mm` in https://github.com/pytorch/pytorch/pull/77462 suggests that this is not needed

Modify `sample_inputs_addmm` to test `torch.addmm` with square matrices, but skip this config for `test_autograd_dense_output_addmm`, see https://github.com/pytorch/pytorch/issues/116565

TODO: probably tweak tolerances, as `test_output_match_addmm_cpu_float16` fails with 2x2 matrices, but passes using 3x3 ones with errors slightly exceeding the tolerance

Fixes https://github.com/pytorch/pytorch/issues/116331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116547
Approved by: https://github.com/albanD, https://github.com/Skylion007
This commit is contained in:
Nikita Shulga
2023-12-30 15:45:46 -08:00
committed by PyTorch MergeBot
parent aef06c316b
commit 4bfaa6bc25
3 changed files with 11 additions and 98 deletions

View File

@ -2789,6 +2789,10 @@ class TestSparseCSR(TestCase):
for sample in samples:
a = sample.args[0].relu().to_sparse_csr()
if sample.args[0].shape == sample.args[1].shape:
import warnings
warnings.warn("Broken for square matrices, see https://github.com/pytorch/pytorch/issues/116565")
continue
# This path tests the autograd path wrt dense inputs
for addmm in [torch.addmm, torch.sparse.addmm]: