[MPS] Add slow version of kthvalue (#161817)

Which heavily borrows implementation logic from `topk`
As this method is non-deterministic, modified the logic for cpu-ops indices comparison with just an equality statement, as by default random numbers picked for input tensor allow for quite a lot of overlaps
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161817
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga
2025-08-29 16:25:46 -07:00
committed by PyTorch MergeBot
parent c1e504ec2f
commit 7c30a9d7fc
4 changed files with 114 additions and 1 deletions

View File

@ -12303,6 +12303,15 @@ class TestConsistency(TestCaseMPS):
if op.name in "grid_sampler_3d":
atol, rtol = 1e-4, 1e-4
if op.name == "kthvalue":
self.assertEqual(cpu_out[0], mps_out[0], atol=atol, rtol=rtol)
# kthvalue is non-deterministic if input has repeated values
dim = cpu_args[2] if len(cpu_args) > 2 else -1
keep_dim = cpu_args[3] if len(cpu_args) > 3 else False
values = torch.gather(mps_sample.input, dim, mps_out[1] if keep_dim else mps_out[1].unsqueeze(dim))
self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0])
continue
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
@ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)