add fp8 support to index_cuda (#144747)

Fixes #133605

**Summary**

This PR adds support for FP8 data types to the `index_cuda` op.

It uses `AT_DISPATCH_V2`, a newer macro that handles an arbitrary number of dtypes, as opposed to the old family of macros that required a separate variant for each possible number of extra dtype arguments (e.g. `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3,4,5,...}`).
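
For reference, a hedged sketch of the `AT_DISPATCH_V2` call shape (not code from this PR; `my_op` and `my_kernel` are placeholder names): dtype *groups* are spliced in with `AT_EXPAND`, individual dtypes are appended directly, and the lambda is wrapped in `AT_WRAP` so commas inside it survive macro expansion.

```cpp
#include <ATen/ATen.h>
#include <ATen/Dispatch_v2.h>

// Placeholder kernel for illustration only.
template <typename scalar_t>
void my_kernel(const at::Tensor& t);

void dispatch_example(const at::Tensor& t) {
  AT_DISPATCH_V2(
      t.scalar_type(),
      "my_op",
      // AT_WRAP shields the lambda (which may contain commas)
      // from being split into separate macro arguments.
      AT_WRAP([&] { my_kernel<scalar_t>(t); }),
      // Dtype groups are expanded in place ...
      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
      AT_EXPAND(AT_FLOAT8_TYPES),
      // ... and individual dtypes are simply appended.
      at::kComplexHalf,
      at::kHalf,
      at::kBool,
      at::kBFloat16);
}
```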

**Test plan**

Updated the `test_index_cuda_with_cpu` test in `test/test_fake_tensor.py` to cover all dtypes handled by `index_cuda`, including the fp8 dtypes.
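
For illustration, a rough eager-mode C++ analogue of what the parametrized test covers (a sketch, not part of this PR; `fp8_index_smoke` is a placeholder name and a CUDA build is assumed): indexing a CUDA fp8 tensor with a CPU int64 index tensor routes through the `index_cuda` dispatch above.

```cpp
#include <ATen/ATen.h>
#include <ATen/TensorIndexing.h>

void fp8_index_smoke() {
  // CUDA fp8 tensor of ones, mirroring the updated Python test.
  at::Tensor x = at::ones(
      {2048}, at::device(at::kCUDA).dtype(at::kFloat8_e4m3fn));
  // CPU int64 indices, as in the test's name ("with_cpu").
  at::Tensor idx = at::zeros({36}, at::dtype(at::kLong));
  at::Tensor out = x.index({idx});
  TORCH_CHECK(out.scalar_type() == at::kFloat8_e4m3fn);
  TORCH_CHECK(out.size(0) == 36);
}
```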

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144747
Approved by: https://github.com/vkuzo
Author: Daniel Vega-Myhre
Date: 2025-01-17 22:53:21 +00:00
Committed by: PyTorch MergeBot
Parent: 4e4b8592a3
Commit: d02c396fbb

2 changed files with 25 additions and 7 deletions

aten/src/ATen/native/cuda/IndexKernel.cu

@@ -193,11 +193,23 @@ void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, c
   });
 }
 
-static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
-    using dtype = OpaqueType<sizeof(scalar_t)>;
-    index_kernel_impl<dtype>(iter, index_size, index_stride);
-  });
+static void index_kernel(
+    TensorIteratorBase& iter,
+    const IntArrayRef index_size,
+    const IntArrayRef index_stride) {
+  AT_DISPATCH_V2(
+      iter.dtype(),
+      "index_cuda",
+      AT_WRAP([&] {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        index_kernel_impl<dtype>(iter, index_size, index_stride);
+      }),
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16);
 }
 
 static void index_fill_kernel(

test/test_fake_tensor.py

@@ -66,6 +66,7 @@ from torch.testing._internal.common_utils import (
     TestCase,
     xfailIfTorchDynamo,
 )
+from torch.testing._internal.common_dtype import all_types_complex_float8_and
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.inductor_utils import GPU_TYPE
@@ -159,11 +160,16 @@ class FakeTensorTest(TestCase):
         TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
     )
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
-    def test_index_cuda_with_cpu(self):
+    @parametrize(
+        "dtype",
+        all_types_complex_float8_and(),
+    )
+    def test_index_cuda_with_cpu(self, dtype):
         with FakeTensorMode():
-            x = torch.rand([2048], device="cuda")
+            x = torch.ones([2048], device="cuda", dtype=dtype)
             out = x[torch.zeros([36], dtype=torch.int64)]
             self.checkType(out, "cuda", [36])
+            self.assertEqual(out.dtype, dtype)
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_shape_take_not_device(self):