[MPS] Speedup argmax/argmin (#159524)

By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed a bug in `simd_argmax` where the result of `simd_ballot` was prematurely cast to `ushort`, and adjusted the unit test accordingly
- Fixed NaN handling in the compiled argmax, though it can't be reliably tested because the MPS (eager) implementation of argmax is buggy
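To see why the premature `ushort` cast matters, here is a small Python model (an illustrative sketch, not the actual Metal code): `simd_ballot` produces one bit per SIMD lane, and Apple GPUs execute SIMD groups of up to 32 lanes, so truncating the vote mask to 16 bits silently discards any winner sitting in lanes 16 and above.

```python
# Sketch of the simd_ballot truncation bug (hypothetical model, not Metal code).
SIMD_WIDTH = 32  # assumed SIMD-group width on Apple GPUs

def ballot(predicate_per_lane):
    """Model of simd_ballot: set one bit per lane that voted true."""
    mask = 0
    for lane, voted in enumerate(predicate_per_lane):
        if voted:
            mask |= 1 << lane
    return mask

votes = [False] * SIMD_WIDTH
votes[20] = True                  # only lane 20 holds the current maximum

full_mask = ballot(votes)         # correct 32-bit vote mask
truncated = full_mask & 0xFFFF    # effect of casting the mask to ushort

# The winning lane is recovered from the set bit of the mask:
# it is present in the full mask, but vanishes after the ushort cast.
assert full_mask == 1 << 20
assert truncated == 0
```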

Now, according to `bench_mps_ops.py`, `max(x, dim=0)` is reliably faster than the eager implementation:
```
[---------------------------------------------------------------------------------------------  --------------------------------------------------------------------------------------------]
                           |  eager-512x512  |  compile-512x512  |  eager-1024x1024  |  compile-1024x1024  |  eager-2048x2048  |  compile-2048x2048  |  eager-4096x4096  |  compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      max (torch.float16)  |      285.8      |       272.2       |       422.3       |        354.5        |       721.6       |        683.5        |       2224.0      |        1979.1
      max (torch.float32)  |      300.2      |       267.0       |       389.6       |        342.5        |       769.4       |        682.6        |       2995.7      |        2609.8
      max (torch.int32)    |      299.6      |       275.4       |       390.0       |        361.7        |       758.7       |        686.1        |       3103.4      |        2646.5
      max (torch.int64)    |      297.5      |       275.5       |       417.0       |        382.1        |       856.1       |        722.6        |       5467.7      |        3156.8

```
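The NaN-handling fix makes the compiled reduction follow NaN-propagating max semantics: if any element is NaN, the max is NaN and the argmax points at that element. A minimal Python reference of that behavior (a hypothetical helper for illustration, not the PyTorch API):

```python
import math

def nan_propagating_max(values):
    """Reference semantics: a NaN anywhere makes the max NaN,
    and the argmax is that NaN's index."""
    best_val, best_idx = values[0], 0
    for i, v in enumerate(values):
        if math.isnan(v):
            return v, i  # NaN dominates every comparison, so it wins
        if v > best_val:
            best_val, best_idx = v, i
    return best_val, best_idx

m, i = nan_propagating_max([1.0, 3.0, float("nan"), 2.0])
assert math.isnan(m) and i == 2
```

This mirrors what the updated unit test below checks: after planting a NaN at `idx`, every reduced value must be NaN and every returned index must equal `idx`.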

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
Author: Nikita Shulga
Date: 2025-07-31 08:42:56 -07:00
Committed by: PyTorch MergeBot
Parent: d2e02585b8
Commit: f946b25865
3 changed files with 79 additions and 47 deletions


```diff
@@ -12504,10 +12504,11 @@ class TestMetalLibrary(TestCaseMPS):
         if not dtype.is_floating_point:
             return
-        x[5] = torch.nan
+        idx = 25
+        x[idx] = torch.nan
         lib.do_max(z0, z1, x)
-        self.assertTrue(z0.isnan().all().item(), "results are {z0}, but all elements shold have been nan")
-        self.assertTrue((z1 == 5).all().item(), "results are {z1}, but all elements shold have been 5")
+        self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements shold have been nan")
+        self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements shold have been {idx}")

     @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
     def test_atomic_add(self, dtype):
```