Add FP8 support for eye (#139974)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139974
Approved by: https://github.com/jgong5, https://github.com/malfet
Author: Jiang, Yanbing
Date: 2024-12-23 05:05:16 +00:00
Committed by: PyTorch MergeBot
Parent: 448c16ac87
Commit: 01890526b9
4 changed files with 19 additions and 10 deletions

aten/src/ATen/native/TensorFactories.cpp

@@ -663,8 +663,10 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
result.zero_();
int64_t sz = std::min<int64_t>(n, m);
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBFloat16, kHalf, kBool, result.scalar_type(), "eye", [&]() -> void {
AT_DISPATCH_V2(
result.scalar_type(),
"eye",
[&]() -> void {
scalar_t* result_data = result.data_ptr<scalar_t>();
at::parallel_for(
0, sz, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
@@ -672,8 +674,12 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
result_data[i * (result.strides()[0] + result.strides()[1])] =
1;
});
});
},
kBFloat16,
kHalf,
kBool,
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES));
return result;
}
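The AND3 variant of the dispatch macro is capped at three extra dtypes on top of the all-types-and-complex set, so the float8 types could not simply be added to it; AT_DISPATCH_V2 takes an open-ended dtype list, and AT_EXPAND(AT_FLOAT8_TYPES) appends them here. A minimal sketch of what the reworked CPU kernel is expected to allow, assuming a build that contains this commit:

import torch

# Sketch: eye() should now fill the diagonal for float8 dtypes on CPU.
# Compare through a float32 upcast, since FP8 support elsewhere is still sparse.
ref = torch.eye(3, 5)
for dt in (torch.float8_e4m3fn, torch.float8_e5m2):
    out = torch.eye(3, 5, dtype=dt)
    assert torch.equal(out.float(), ref), f"mismatch for {dt}"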

aten/src/ATen/native/cpu/TensorCompareKernel.cpp

@@ -213,14 +213,15 @@ static void aminmax_kernel(
}
static void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
AT_DISPATCH_V2(
iter.dtype(), "where_cpu", [&] {
cpu_kernel(
iter,
[=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
});
},
kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
}
static void isposinf_kernel_impl(TensorIteratorBase& iter) {
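The CPU where kernel gets the same rewrite: the fixed-arity AND4 macro becomes AT_DISPATCH_V2 with the float8 types appended to the existing kComplexHalf/kHalf/kBFloat16/kBool list. A usage sketch, assuming this commit is in the build; the FP8 operands are created by casting, since most factory paths are still float32-first:

import torch

cond = torch.tensor([True, False, True, False])
a = torch.full((4,), 2.0).to(torch.float8_e4m3fn)
b = torch.zeros(4).to(torch.float8_e4m3fn)

out = torch.where(cond, a, b)   # now dispatches to an FP8 instantiation on CPU
assert out.dtype == torch.float8_e4m3fn
print(out.float())              # tensor([2., 0., 2., 0.])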

aten/src/ATen/native/cuda/TensorCompare.cu

@@ -1,6 +1,7 @@
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/NumericUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/cuda/Loops.cuh>
@@ -12,13 +13,14 @@ namespace at::native {
namespace {
void where_kernel_impl(TensorIterator &iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
AT_DISPATCH_V2(iter.dtype(), "where_cuda", [&] {
gpu_kernel(
iter,
[=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
return cond_val ? self_val : other_val;
});
});
},
kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
}
void isposinf_kernel_impl(TensorIteratorBase &iter) {
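The CUDA kernel mirrors the CPU change, with Dispatch_v2.h included so AT_DISPATCH_V2 is available in the .cu file. A sketch of the GPU path, assuming a CUDA build that contains this commit; the device check keeps it runnable on CPU-only machines:

import torch

if torch.cuda.is_available():
    cond = torch.rand(1024, device="cuda") > 0.5
    a = torch.ones(1024, device="cuda").to(torch.float8_e5m2)
    b = torch.zeros(1024, device="cuda").to(torch.float8_e5m2)
    out = torch.where(cond, a, b)  # exercises the gpu_kernel FP8 instantiation
    assert out.is_cuda and out.dtype == torch.float8_e5m2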

torch/testing/_internal/common_methods_invocations.py

@@ -21,7 +21,7 @@ from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
empty_types, complex_types_and, integral_types, custom_types,
empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
)
from torch.testing._internal.common_device_type import \
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -19154,7 +19154,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
)),
OpInfo('eye',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_eye,
error_inputs_func=error_inputs_eye,
supports_out=True,
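all_types_complex_float8_and is the dtype-list helper imported above so that the eye OpInfo advertises the float8 dtypes and the generic OpInfo tests cover them. A small sketch of what the new dtype list is expected to contain; the helper lives in an internal testing module, so treat this as illustrative rather than a stable API:

import torch
from torch.testing._internal.common_dtype import all_types_complex_float8_and

dtypes = all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16)
assert torch.float8_e4m3fn in dtypes and torch.float8_e5m2 in dtypes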