diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index 4c4b983e8bff..d729f04fabec 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -193,11 +193,23 @@ void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, c
   });
 }
 
-static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
-    using dtype = OpaqueType<sizeof(scalar_t)>;
-    index_kernel_impl<dtype>(iter, index_size, index_stride);
-  });
+static void index_kernel(
+    TensorIteratorBase& iter,
+    const IntArrayRef index_size,
+    const IntArrayRef index_stride) {
+  AT_DISPATCH_V2(
+      iter.dtype(),
+      "index_cuda",
+      AT_WRAP([&] {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        index_kernel_impl<dtype>(iter, index_size, index_stride);
+      }),
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16);
 }
 
 static void index_fill_kernel(
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index e94bf27a9780..d164010cda92 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -66,6 +66,7 @@ from torch.testing._internal.common_utils import (
     TestCase,
     xfailIfTorchDynamo,
 )
+from torch.testing._internal.common_dtype import all_types_complex_float8_and
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.inductor_utils import GPU_TYPE
 
@@ -159,11 +160,16 @@ class FakeTensorTest(TestCase):
         TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
     )
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
-    def test_index_cuda_with_cpu(self):
+    @parametrize(
+        "dtype",
+        all_types_complex_float8_and(),
+    )
+    def test_index_cuda_with_cpu(self, dtype):
         with FakeTensorMode():
-            x = torch.rand([2048], device="cuda")
+            x = torch.ones([2048], device="cuda", dtype=dtype)
             out = x[torch.zeros([36], dtype=torch.int64)]
             self.checkType(out, "cuda", [36])
+            self.assertEqual(out.dtype, dtype)
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_shape_take_not_device(self):