mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Speedup argmax/argmin (#159524)
By using efficient `threadgroup_arg[max|min]` primitives.
- Fixed a bug in `simd_argmax` where the result of `simd_ballot` was prematurely cast to `ushort`, and adjusted the unit test
- Fixed NaN handling in compiled argmax, but this can't be reliably tested because the MPS (eager) implementation of argmax is buggy

Now, according to `bench_mps_ops.py`, `max(x, dim=0)` is reliably faster than the eager implementation:
```
[-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------]
                     | eager-512x512 | compile-512x512 | eager-1024x1024 | compile-1024x1024 | eager-2048x2048 | compile-2048x2048 | eager-4096x4096 | compile-4096x4096
1 threads: ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
max (torch.float16)  |         285.8 |           272.2 |           422.3 |             354.5 |           721.6 |             683.5 |          2224.0 |            1979.1
max (torch.float32)  |         300.2 |           267.0 |           389.6 |             342.5 |           769.4 |             682.6 |          2995.7 |            2609.8
max (torch.int32)    |         299.6 |           275.4 |           390.0 |             361.7 |           758.7 |             686.1 |          3103.4 |            2646.5
max (torch.int64)    |         297.5 |           275.5 |           417.0 |             382.1 |           856.1 |             722.6 |          5467.7 |            3156.8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159524
Approved by: https://github.com/Skylion007, https://github.com/dcci
ghstack dependencies: #158990
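The `simd_ballot` fix mentioned above can be illustrated in plain Python. This is only a sketch of the failure mode, not the Metal code from the PR: it assumes a 32-lane SIMD group and assumes the winning lane is taken as the lowest set bit of the ballot mask. The helpers `ballot` and `first_set_lane` are hypothetical names introduced for illustration.

```python
def ballot(predicates):
    """Pack one boolean per lane into an integer bit mask (lane i -> bit i)."""
    mask = 0
    for lane, flag in enumerate(predicates):
        if flag:
            mask |= 1 << lane
    return mask


def first_set_lane(mask):
    """Return the index of the lowest set bit, or None if no lane voted."""
    if mask == 0:
        return None
    return (mask & -mask).bit_length() - 1


# Hypothetical data: 32 lanes (assumed SIMD width), maximum held by lane 31.
lanes = [float(i) for i in range(32)]
best = max(lanes)

full_mask = ballot([v == best for v in lanes])
truncated_mask = full_mask & 0xFFFF  # models a premature cast to a 16-bit ushort

winner_full = first_set_lane(full_mask)            # lane 31 is found
winner_truncated = first_set_lane(truncated_mask)  # lanes 16..31 were dropped
```

Under these assumptions, any winning lane with index 16 or higher vanishes from the truncated mask, which may be why a test that places the NaN only at a low index would not expose the problem.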
committed by PyTorch MergeBot
parent d2e02585b8
commit f946b25865
@@ -12504,10 +12504,11 @@ class TestMetalLibrary(TestCaseMPS):
         if not dtype.is_floating_point:
             return
 
-        x[5] = torch.nan
+        idx = 25
+        x[idx] = torch.nan
         lib.do_max(z0, z1, x)
-        self.assertTrue(z0.isnan().all().item(), "results are {z0}, but all elements shold have been nan")
-        self.assertTrue((z1 == 5).all().item(), "results are {z1}, but all elements shold have been 5")
+        self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements shold have been nan")
+        self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements shold have been {idx}")
 
     @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16])
     def test_atomic_add(self, dtype):
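The behavior the updated test asserts (a NaN in the input yields a NaN maximum, reported at the NaN's index) can be sketched as a pure-Python reference. This is a minimal illustration under the assumption that the first NaN encountered wins. `nan_propagating_argmax` is a hypothetical helper, not part of PyTorch or of the Metal library under test.

```python
import math


def nan_propagating_argmax(values):
    """Return (max_value, index), treating the first NaN as the maximum."""
    best_idx, best_val = 0, values[0]
    for i, v in enumerate(values):
        if math.isnan(best_val):
            break  # a NaN has already been selected; nothing can replace it
        if math.isnan(v) or v > best_val:
            best_idx, best_val = i, v
    return best_val, best_idx


# Mirror the test setup: a single NaN placed at index 25 (hypothetical data).
idx = 25
x = [float(i) for i in range(32)]
x[idx] = math.nan
nan_val, nan_idx = nan_propagating_argmax(x)

# Without NaN, the helper behaves like an ordinary argmax.
plain_val, plain_idx = nan_propagating_argmax([1.0, 3.0, 2.0])
```

Note also that the diff turns the assertion messages into f-strings, so `{z0}`, `{z1}`, and `{idx}` are interpolated in the failure output rather than printed literally.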