mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
MPS: TopK raise an error if K>16 (#79677)
* Error out in TopK when k>16. * Add a test case too. Fixes #78915 Pull Request resolved: https://github.com/pytorch/pytorch/pull/79677 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
270c518be0
commit
355a1c8c3f
@ -300,6 +300,8 @@ TORCH_IMPL_FUNC(topk_out_mps)
|
|||||||
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
|
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
|
||||||
"selected index k out of range");
|
"selected index k out of range");
|
||||||
|
|
||||||
|
TORCH_CHECK( k <= 16 , "Currently topk on mps works only for k<=16 ");
|
||||||
|
|
||||||
if (self.dim() == 0 && self.numel() == 1)
|
if (self.dim() == 0 && self.numel() == 1)
|
||||||
{
|
{
|
||||||
values.copy_(self);
|
values.copy_(self);
|
||||||
|
@ -2801,6 +2801,16 @@ class TestNLLLoss(TestCase):
|
|||||||
|
|
||||||
helper(3, 3)
|
helper(3, 3)
|
||||||
|
|
||||||
|
def test_assert_topk(self):
|
||||||
|
# here the k > 16 raises an error as expected
|
||||||
|
with self.assertRaisesRegex(RuntimeError, "Currently topk on mps works only for k<=16"):
|
||||||
|
xs = torch.arange(30).to('mps')
|
||||||
|
xs.topk(30)
|
||||||
|
# for k <= 16 it works fine
|
||||||
|
ys_cpu = torch.arange(30)
|
||||||
|
ys_mps = ys_cpu.to('mps')
|
||||||
|
self.assertEqual(ys_cpu.topk(16), ys_mps.topk(16))
|
||||||
|
|
||||||
def test_topk(self):
|
def test_topk(self):
|
||||||
def helper(shape):
|
def helper(shape):
|
||||||
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
|
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
|
||||||
|
Reference in New Issue
Block a user