add fp8 support to index_cuda (#144747)
Fixes #133605

**Summary**

This PR adds support for FP8 data types to the `index_cuda` op. It uses `AT_DISPATCH_V2`, a newer macro that handles an arbitrary number of dtypes, as opposed to the old implementation, which needed a separate macro for each possible count of extra dtype arguments (e.g. `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3,4,5...}`).

**Test plan**

Updated the test `index_cuda_with_cpu` in `test/test_fake_tensor.py` to cover all dtypes handled by `index_cuda`, including the fp8 dtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144747
Approved by: https://github.com/vkuzo
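To illustrate the user-visible effect, here is a minimal sketch (not part of the PR) mirroring the updated test; it assumes a CUDA device is available and uses `torch.float8_e4m3fn` as a representative fp8 dtype:

```python
import torch

# Advanced indexing on an fp8 CUDA tensor, which this PR enables.
# Casting via .to() avoids relying on fp8 fill support.
x = torch.randn(2048, device="cuda").to(torch.float8_e4m3fn)
idx = torch.zeros(36, dtype=torch.int64, device="cuda")
out = x[idx]
assert out.dtype == torch.float8_e4m3fn
assert out.shape == (36,)
```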
Committed by: PyTorch MergeBot
Parent: 4e4b8592a3
Commit: d02c396fbb
aten/src/ATen/native/cuda/IndexKernel.cu

@@ -193,11 +193,23 @@ void index_put_kernel_impl(TensorIterator& iter, const IntArrayRef index_size, c
   });
 }
 
-static void index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, const IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16, iter.dtype(), "index_cuda", [&] {
-    using dtype = OpaqueType<sizeof(scalar_t)>;
-    index_kernel_impl<dtype>(iter, index_size, index_stride);
-  });
+static void index_kernel(
+    TensorIteratorBase& iter,
+    const IntArrayRef index_size,
+    const IntArrayRef index_stride) {
+  AT_DISPATCH_V2(
+      iter.dtype(),
+      "index_cuda",
+      AT_WRAP([&] {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        index_kernel_impl<dtype>(iter, index_size, index_stride);
+      }),
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES),
+      kComplexHalf,
+      kHalf,
+      kBool,
+      kBFloat16);
 }
 
 static void index_fill_kernel(
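One design note, as a reading of the diff rather than something the commit message states: `index_kernel` dispatches on `OpaqueType<sizeof(scalar_t)>`, i.e. on element size rather than on the dtype itself, since indexing only copies elements byte-for-byte. All fp8 variants are one byte wide, so they reuse the same one-byte instantiation. A quick Python check of the element sizes (the exact membership of `AT_FLOAT8_TYPES` is an assumption here):

```python
import torch

# Every fp8 variant has itemsize 1, so OpaqueType<1> covers them all.
fp8_dtypes = [
    torch.float8_e4m3fn,
    torch.float8_e5m2,
    torch.float8_e4m3fnuz,  # assumed to be in AT_FLOAT8_TYPES
    torch.float8_e5m2fnuz,  # assumed to be in AT_FLOAT8_TYPES
]
assert all(dt.itemsize == 1 for dt in fp8_dtypes)
```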
test/test_fake_tensor.py

@@ -66,6 +66,7 @@ from torch.testing._internal.common_utils import (
     TestCase,
     xfailIfTorchDynamo,
 )
+from torch.testing._internal.common_dtype import all_types_complex_float8_and
 from torch.testing._internal.custom_op_db import custom_op_db
 
 from torch.testing._internal.inductor_utils import GPU_TYPE
@@ -159,11 +160,16 @@ class FakeTensorTest(TestCase):
         TEST_WITH_TORCHDYNAMO, "isinstance check for FakeTensor won't work with compile"
     )
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
-    def test_index_cuda_with_cpu(self):
+    @parametrize(
+        "dtype",
+        all_types_complex_float8_and(),
+    )
+    def test_index_cuda_with_cpu(self, dtype):
         with FakeTensorMode():
-            x = torch.rand([2048], device="cuda")
+            x = torch.ones([2048], device="cuda", dtype=dtype)
             out = x[torch.zeros([36], dtype=torch.int64)]
             self.checkType(out, "cuda", [36])
+            self.assertEqual(out.dtype, dtype)
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_shape_take_not_device(self):
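For a hands-on check outside the test suite, the updated test body can be reproduced by hand for a single fp8 dtype (a sketch, assuming a CUDA build; under `FakeTensorMode` no real kernel is launched):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Mirrors one parametrization of test_index_cuda_with_cpu.
with FakeTensorMode():
    x = torch.ones([2048], device="cuda", dtype=torch.float8_e4m3fn)
    out = x[torch.zeros([36], dtype=torch.int64)]
    assert out.device.type == "cuda"
    assert out.shape == (36,)
    assert out.dtype == torch.float8_e4m3fn
```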