mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Add slow version of kthvalue
(#161817)
Which heavily borrows implementation logic from `topk` As this method is non-deterministic, modified the logic for cpu-ops indices comparison with just an equality statement, as by default random numbers picked for input tensor allow for quite a lot of overlaps Pull Request resolved: https://github.com/pytorch/pytorch/pull/161817 Approved by: https://github.com/dcci
This commit is contained in:
committed by
PyTorch MergeBot
parent
c1e504ec2f
commit
7c30a9d7fc
@ -12303,6 +12303,15 @@ class TestConsistency(TestCaseMPS):
|
||||
if op.name in "grid_sampler_3d":
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
|
||||
if op.name == "kthvalue":
|
||||
self.assertEqual(cpu_out[0], mps_out[0], atol=atol, rtol=rtol)
|
||||
# kthvalue is non-deterministic if input has repeated values
|
||||
dim = cpu_args[2] if len(cpu_args) > 2 else -1
|
||||
keep_dim = cpu_args[3] if len(cpu_args) > 3 else False
|
||||
values = torch.gather(mps_sample.input, dim, mps_out[1] if keep_dim else mps_out[1].unsqueeze(dim))
|
||||
self.assertEqual(values if keep_dim else values.squeeze(dim), mps_out[0])
|
||||
continue
|
||||
|
||||
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
|
||||
|
||||
@ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
|
||||
|
Reference in New Issue
Block a user