Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add FP8 support for eye (#139974)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139974
Approved by: https://github.com/jgong5, https://github.com/malfet

Committed by: PyTorch MergeBot
Parent: 448c16ac87
Commit: 01890526b9
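The diff below widens the dtype dispatch of the CPU eye kernel and of the CPU/CUDA where kernels to include the float8 formats. As a minimal sketch of the user-visible effect (assuming a build that includes this commit; torch.float8_e4m3fn stands in for any of the float8 dtypes):

import torch

# Identity matrix materialized directly in an FP8 format. Before this change the
# dispatch macro had no float8 entry, so a call like this failed with a
# "not implemented for 'Float8_e4m3fn'"-style dispatch error.
eye_fp8 = torch.eye(3, dtype=torch.float8_e4m3fn)
print(eye_fp8.dtype)               # torch.float8_e4m3fn
print(eye_fp8.to(torch.float32))   # upcast only to make the printout readable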
@@ -663,8 +663,10 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
   result.zero_();
 
   int64_t sz = std::min<int64_t>(n, m);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      kBFloat16, kHalf, kBool, result.scalar_type(), "eye", [&]() -> void {
+  AT_DISPATCH_V2(
+      result.scalar_type(),
+      "eye",
+      [&]() -> void {
         scalar_t* result_data = result.data_ptr<scalar_t>();
         at::parallel_for(
             0, sz, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
@@ -672,8 +674,12 @@ Tensor& eye_out_cpu(int64_t n, int64_t m, Tensor& result) {
               result_data[i * (result.strides()[0] + result.strides()[1])] =
                   1;
             });
-      });
+      },
+      kBFloat16,
+      kHalf,
+      kBool,
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES));
 
   return result;
 }
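These hunks migrate eye_out_cpu from AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3 to AT_DISPATCH_V2, which takes the kernel body as a lambda followed by an explicit list of dtype groups; appending AT_EXPAND(AT_FLOAT8_TYPES) is what brings in the float8 coverage. Since eye_out_cpu backs the out= overload, the rectangular and preallocated forms pick up FP8 as well. A small sketch (torch.float8_e5m2 chosen arbitrarily):

import torch

# Rectangular identity written into a preallocated FP8 buffer; the out= path
# resizes the output to (n, m) and fills the main diagonal with 1.
out = torch.empty(0, dtype=torch.float8_e5m2)
torch.eye(4, 6, out=out)
print(out.shape, out.dtype)        # torch.Size([4, 6]) torch.float8_e5m2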
@@ -213,14 +213,15 @@ static void aminmax_kernel(
 }
 
 static void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool,
+  AT_DISPATCH_V2(
     iter.dtype(), "where_cpu", [&] {
       cpu_kernel(
         iter,
         [=](bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
           return cond_val ? self_val : other_val;
         });
-  });
+  },
+  kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
 }
 
 static void isposinf_kernel_impl(TensorIteratorBase& iter) {
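The same AND4-to-AT_DISPATCH_V2 rewrite is applied to the CPU where kernel, so torch.where can now select between float8 operands. A sketch of that (the condition is built as a bool mask directly, since this commit only touches the where kernels, not comparison ops; the float8 values are produced by casting, which was already supported):

import torch

a = torch.ones(4).to(torch.float8_e4m3fn)
b = torch.zeros(4).to(torch.float8_e4m3fn)
cond = torch.tensor([True, False, True, False])

# Element-wise select over float8 inputs, handled by the widened where_cpu dispatch.
out = torch.where(cond, a, b)
print(out.to(torch.float32))       # tensor([1., 0., 1., 0.])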
@@ -1,6 +1,7 @@
 #define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/NumericUtils.h>
 #include <ATen/Dispatch.h>
+#include <ATen/Dispatch_v2.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorCompare.h>
 #include <ATen/native/cuda/Loops.cuh>
@@ -12,13 +13,14 @@ namespace at::native {
 namespace {
 
 void where_kernel_impl(TensorIterator &iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_cuda", [&] {
+  AT_DISPATCH_V2(iter.dtype(), "where_cuda", [&] {
      gpu_kernel(
        iter,
        [=] GPU_LAMBDA (bool cond_val, scalar_t self_val, scalar_t other_val) -> scalar_t {
          return cond_val ? self_val : other_val;
        });
-  });
+  },
+  kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_FLOAT8_TYPES));
 }
 
 void isposinf_kernel_impl(TensorIteratorBase &iter) {
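The CUDA side gets the matching change: a new <ATen/Dispatch_v2.h> include and the same AT_DISPATCH_V2 rewrite in where_cuda, so GPU behavior lines up with the CPU path. A sketch that cross-checks the two (guarded because a CUDA device may not be present):

import torch

if torch.cuda.is_available():
    cond = torch.tensor([True, False, False, True])
    a = torch.ones(4).to(torch.float8_e5m2)
    b = torch.zeros(4).to(torch.float8_e5m2)

    cpu_out = torch.where(cond, a, b)
    gpu_out = torch.where(cond.cuda(), a.cuda(), b.cuda())

    # Both backends now dispatch float8 through the same widened type list.
    assert torch.equal(cpu_out.to(torch.float32), gpu_out.cpu().to(torch.float32))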
@@ -21,7 +21,7 @@ from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    empty_types, complex_types_and, integral_types, custom_types,
+    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
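On the test side, the hunk above pulls in the dtype helper all_types_complex_float8_and alongside the existing ones. Its exact contents live in common_dtype; the sketch below only illustrates the assumed relationship to the older all_types_and_complex_and helper, namely that it adds the float8 formats on top of the same base set:

import torch
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and,
    all_types_complex_float8_and,
)

base = set(all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
with_fp8 = set(all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16))

# Expected (assumption): the difference is exactly the float8 dtypes.
print(sorted(str(d) for d in with_fp8 - base))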
@@ -19154,7 +19154,7 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
            )),
     OpInfo('eye',
-           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
+           dtypes=all_types_complex_float8_and(torch.bool, torch.half, torch.bfloat16),
            sample_inputs_func=sample_inputs_eye,
            error_inputs_func=error_inputs_eye,
            supports_out=True,
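With the import in place, the eye OpInfo switches its dtypes entry to the float8-inclusive helper, so the OpInfo-driven suites (sample inputs, out= checks, and so on) now exercise eye in the float8 formats as well. A stripped-down, standalone approximation of that coverage (the real testing goes through the OpInfo machinery, not this loop):

import torch

for dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
    got = torch.eye(3, 5, dtype=dtype).to(torch.float32)
    ref = torch.eye(3, 5, dtype=torch.float32)
    # 0 and 1 are exactly representable in every float8 format, so an exact
    # comparison against the float32 reference is safe here.
    assert torch.equal(got, ref), dtype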