Add Half support for addcmul, addcdiv, cumsum, and topk on CPU (#103319)
Note: this PR introduces the issue tracked at https://github.com/pytorch/pytorch/issues/111454.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103319
Approved by: https://github.com/jgong5, https://github.com/cpuhrsch
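As a rough sketch of the user-facing effect (not part of the diff below), these are the calls on CPU float16 tensors that the change is meant to enable; before the patch they fail with a "not implemented for 'Half'"-style dispatch error:

import torch

a = torch.randn(8, dtype=torch.half)          # CPU float16 inputs
b = torch.randn(8, dtype=torch.half)
c = torch.rand(8, dtype=torch.half) + 0.5     # keep the divisor away from zero

out_mul = torch.addcmul(a, b, c, value=0.5)   # a + 0.5 * b * c
out_div = torch.addcdiv(a, b, c, value=0.5)   # a + 0.5 * b / c
csum = torch.cumsum(a, dim=0)
values, indices = torch.topk(a, k=3)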
@@ -5,34 +5,34 @@
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
 #include <c10/core/Scalar.h>
 #include <ATen/cpu/vec/functional.h>
 
 namespace at::native {
 
 namespace {
 
 static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
   ScalarType dtype = iter.common_dtype();
-  if (dtype == kBFloat16) {
-    float float_val = value.to<float>();
-    auto float_vec = Vectorized<float>(float_val);
-    cpu_kernel_vec(
-        iter,
-        [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 {
-          return float(self_val) + float_val * float(t1_val) * float(t2_val);
-        },
-        [=](Vectorized<BFloat16> self_vec,
-            Vectorized<BFloat16> t1_vec,
-            Vectorized<BFloat16> t2_vec) {
-          Vectorized<float> self_vec0, self_vec1;
-          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
-          Vectorized<float> t1_vec0, t1_vec1, t2_vec0, t2_vec1;
-          std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec);
-          std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec);
-          self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0;
-          self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1;
-          return convert_float_bfloat16(self_vec0, self_vec1);
-        });
+  if (at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcmul_cpu_out", [&]() {
+      float float_val = value.to<float>();
+      auto float_vec = Vectorized<float>(float_val);
+      cpu_kernel_vec(
+          iter,
+          [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
+            return float(self_val) + float_val * float(t1_val) * float(t2_val);
+          },
+          [=](Vectorized<scalar_t> self_vec,
+              Vectorized<scalar_t> t1_vec,
+              Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
+            auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
+            auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
+            auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
+            self_vec0 = self_vec0 + float_vec * t1_vec0 * t2_vec0;
+            self_vec1 = self_vec1 + float_vec * t1_vec1 * t2_vec1;
+            return convert_from_float<scalar_t>(self_vec0, self_vec1);
+          });
+    });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::ComplexHalf, at::ScalarType::Half,
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::ComplexHalf,
         dtype, "addcmul_cpu_out", [&] {
       scalar_t scalar_val = value.to<scalar_t>();
       auto scalar_vec = Vectorized<scalar_t>(scalar_val);
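The new branch handles BFloat16 and Half the same way: inputs are widened to float32, self + value * tensor1 * tensor2 is accumulated in float, and only the final result is rounded back to the reduced type. A minimal Python sketch of that numerical pattern (an illustration written for this note, not code from the PR):

import torch

def addcmul_reduced_reference(self_t, t1, t2, value=1.0):
    # Widen to float32, accumulate there, round back to the input dtype once.
    out = self_t.float() + float(value) * t1.float() * t2.float()
    return out.to(self_t.dtype)

x, t1, t2 = (torch.randn(16, dtype=torch.half) for _ in range(3))
print(addcmul_reduced_reference(x, t1, t2, value=0.5))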
@@ -52,26 +52,26 @@ static void addcmul_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
 
 static void addcdiv_cpu_kernel(TensorIteratorBase& iter, const Scalar& value) {
   ScalarType dtype = iter.common_dtype();
-  if (dtype == kBFloat16) {
-    float float_val = value.to<float>();
-    auto float_vec = Vectorized<float>(float_val);
-    cpu_kernel_vec(
-        iter,
-        [=](BFloat16 self_val, BFloat16 t1_val, BFloat16 t2_val) -> BFloat16 {
-          return float(self_val) + float_val * float(t1_val) / float(t2_val);
-        },
-        [=](Vectorized<BFloat16> self_vec,
-            Vectorized<BFloat16> t1_vec,
-            Vectorized<BFloat16> t2_vec) {
-          Vectorized<float> self_vec0, self_vec1;
-          std::tie(self_vec0, self_vec1) = convert_bfloat16_float(self_vec);
-          Vectorized<float> t1_vec0, t1_vec1, t2_vec0, t2_vec1;
-          std::tie(t1_vec0, t1_vec1) = convert_bfloat16_float(t1_vec);
-          std::tie(t2_vec0, t2_vec1) = convert_bfloat16_float(t2_vec);
-          self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0;
-          self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1;
-          return convert_float_bfloat16(self_vec0, self_vec1);
-        });
+  if (at::isReducedFloatingType(dtype)) {
+    AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "addcdiv_cpu_out", [&]() {
+      float float_val = value.to<float>();
+      auto float_vec = Vectorized<float>(float_val);
+      cpu_kernel_vec(
+          iter,
+          [=](scalar_t self_val, scalar_t t1_val, scalar_t t2_val) -> scalar_t {
+            return float(self_val) + float_val * float(t1_val) / float(t2_val);
+          },
+          [=](Vectorized<scalar_t> self_vec,
+              Vectorized<scalar_t> t1_vec,
+              Vectorized<scalar_t> t2_vec) -> Vectorized<scalar_t> {
+            auto [self_vec0, self_vec1] = convert_to_float<scalar_t>(self_vec);
+            auto [t1_vec0, t1_vec1] = convert_to_float<scalar_t>(t1_vec);
+            auto [t2_vec0, t2_vec1] = convert_to_float<scalar_t>(t2_vec);
+            self_vec0 = self_vec0 + float_vec * t1_vec0 / t2_vec0;
+            self_vec1 = self_vec1 + float_vec * t1_vec1 / t2_vec1;
+            return convert_from_float<scalar_t>(self_vec0, self_vec1);
+          });
+    });
   } else {
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX(dtype, "addcdiv_cpu_out", [&] {
       scalar_t scalar_val = value.to<scalar_t>();
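addcdiv follows the same widen-to-float pattern, with a division in place of the second multiplication. A quick way to sanity-check the new Half path against a float32 reference (a sketch; the tolerances are illustrative, not taken from the test suite):

import torch

a = torch.randn(64, dtype=torch.half)
b = torch.randn(64, dtype=torch.half)
c = torch.rand(64, dtype=torch.half) + 0.5   # keep divisors away from zero

half_out = torch.addcdiv(a, b, c, value=0.25)
ref_out = torch.addcdiv(a.float(), b.float(), c.float(), value=0.25)
torch.testing.assert_close(half_out, ref_out.half(), atol=1e-3, rtol=1e-3)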
@@ -81,7 +81,7 @@ static void cumsum_cpu_kernel(const Tensor& result, const Tensor& self, int64_t
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, self.scalar_type(), "cumsum_out_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, self.scalar_type(), "cumsum_out_cpu", [&] {
     cpu_cum_base_kernel<scalar_t>(result, self, wrap_dim, [&] (
       scalar_t* result_data, auto result_dim_stride,
       const scalar_t* self_data, auto self_dim_stride, scalar_t init_val) {
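With kHalf added to the dispatch list, torch.cumsum on CPU float16 input runs through the same cpu_cum_base_kernel instantiation that BFloat16 already used. A small usage sketch:

import torch

x = torch.arange(6, dtype=torch.half)   # [0., 1., 2., 3., 4., 5.] on CPU
print(torch.cumsum(x, dim=0))           # -> [0., 1., 3., 6., 10., 15.] in float16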
@@ -223,7 +223,7 @@ static void topk_kernel(
   auto mode_indices_stride = indices.strides()[dim];
   auto tmp_values_stride = self.strides()[dim];
 
-  AT_DISPATCH_ALL_TYPES_AND(ScalarType::BFloat16, self.scalar_type(), "topk_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "topk_cpu", [&] {
     auto loop = [&](char** data, const int64_t* strides, int64_t n) {
       if (self.scalar_type() == ScalarType::BFloat16) {
         return topk_impl_loop<scalar_t, float>(
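Similarly for topk, the extra dispatch entry makes the existing CPU loop available for Half input. A usage sketch:

import torch

x = torch.tensor([0.5, -1.0, 3.0, 2.0, 0.25], dtype=torch.half)
values, indices = torch.topk(x, k=2)
print(values)   # -> [3., 2.] in float16
print(indices)  # -> [2, 3]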
@@ -371,6 +371,15 @@ EXPECTED_SKIPS_OR_FAILS: Tuple[onnx_test_common.DecorateMeta, ...] = (
         "cumsum", dtypes=onnx_test_common.BOOL_TYPES + (torch.uint8, torch.int8, torch.int16,),
         reason=onnx_test_common.reason_onnx_does_not_support("Cumsum", "bool, uint8, int8, int16")
     ),
+    # See https://github.com/pytorch/pytorch/issues/111454
+    xfail(
+        "cumsum", dtypes=(torch.float16,),
+        reason=onnx_test_common.reason_onnx_runtime_does_not_support("RUNTIME_EXCEPTION : \
+            Exception during initialization: /onnxruntime_src/onnxruntime/core/framework/\
+            allocation_planner.cc:230 int& onnxruntime::PlannerImpl::\
+            UseCount(onnxruntime::OrtValueIndex) n >= 0 && static_cast<size_t>(n) \
+            < ort_value_info_.size() was false.")
+    ),
     xfail(
         "cross",
         reason=onnx_test_common.reason_onnx_script_does_not_support("linalg_cross"),
@@ -10898,7 +10898,7 @@ class TestConsistency(TestCaseMPS):
     # You most likely do NOT want to modify this manually
 
     FP16_LOW_PRECISION_LIST = {
-        'add', 'sub', 'div',
+        'add', 'sub', 'div', 'addcdiv',
         '__rdiv__', '__rmul__',
         'nn.functional.huber_loss',
         'true_divide', 'kron',
@@ -754,9 +754,8 @@ class TestSortAndSelect(TestCase):
         for curr_size in (small, large, verylarge):
             self._test_topk_dtype(device, dtype, True, curr_size)
 
-    @onlyCUDA
-    @dtypes(torch.bfloat16)
-    def test_topk_bfloat16(self, device, dtype):
+    @dtypes(torch.bfloat16, torch.half)
+    def test_topk_lower_precision(self, device, dtype):
 
         small = 10
         large = 4096
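Dropping @onlyCUDA and adding torch.half means the renamed test now exercises the CPU kernels as well. A condensed, self-contained version of the kind of check such a test performs, comparing the reduced-precision result against a float32 run (my sketch, not the body of the _test_topk_dtype helper):

import torch

def check_topk_low_precision(size, dtype, device="cpu"):
    x = torch.randn(size, device=device).to(dtype)
    k = size // 2
    vals, idx = x.topk(k)
    # Reference: cast to float32 (exact for half/bfloat16), then topk there.
    ref_vals, _ = x.float().topk(k)
    torch.testing.assert_close(vals.float(), ref_vals)
    # The returned indices must point at the returned values.
    torch.testing.assert_close(x[idx], vals)

for dtype in (torch.bfloat16, torch.half):
    check_topk_low_precision(10, dtype)
    check_topk_low_precision(4096, dtype)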
@@ -765,7 +764,7 @@ class TestSortAndSelect(TestCase):
             self._test_topk_dtype(device, dtype, False, curr_size)
 
     @dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
-    @dtypes(torch.float, torch.double, torch.bfloat16)
+    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
    def test_topk_nonfinite(self, device, dtype):
        x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype)
        val, idx = x.topk(4)
@@ -796,7 +795,7 @@ class TestSortAndSelect(TestCase):
 
     @onlyNativeDeviceTypes
     @dtypesIfCUDA(*all_types_and(torch.bfloat16))
-    @dtypes(*all_types())
+    @dtypes(*all_types_and(torch.bfloat16, torch.half))
     def test_topk_zero(self, device, dtype):
         # https://github.com/pytorch/pytorch/issues/49205
         t = torch.rand(2, 2, device=device).to(dtype=dtype)
@@ -10164,8 +10164,7 @@ op_db: List[OpInfo] = [
            reference_inputs_func=partial(
                reference_inputs_elementwise_ternary, sample_inputs_func=reference_inputs_addcmul_addcdiv)),
     OpInfo('addcdiv',
-           dtypes=floating_and_complex_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
+           dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
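The OpInfo edits here and in the hunks below (cumsum, topk, cumulative_trapezoid, masked.cumsum) all follow the same pattern: once the CPU kernels accept the same reduced-precision dtypes as CUDA, the separate dtypesIfCUDA line becomes redundant and is folded into dtypes. On a source checkout where the internal testing module is importable, the effect can be observed directly (a sketch for illustration):

from torch.testing._internal.common_methods_invocations import op_db
import torch

addcdiv_info = next(op for op in op_db if op.name == "addcdiv")
print(torch.float16 in addcdiv_info.supported_dtypes("cpu"))   # True after this change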
@@ -10788,8 +10787,7 @@ op_db: List[OpInfo] = [
            supports_out=True,
            supports_forward_ad=True),
     OpInfo('cumsum',
-           dtypes=all_types_and_complex_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            skips=(
@@ -14022,8 +14020,7 @@ op_db: List[OpInfo] = [
            ),
        ),
     OpInfo('topk',
-           dtypes=all_types_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and(torch.bfloat16, torch.float16),
+           dtypes=all_types_and(torch.bfloat16, torch.float16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            assert_jit_shape_analysis=True,
@@ -17033,8 +17030,7 @@ op_db: List[OpInfo] = [
           check_batched_forward_grad=False,
           sample_inputs_func=sample_trapezoid),
     OpInfo('cumulative_trapezoid',
-           dtypes=all_types_and_complex_and(torch.bfloat16),
-           dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.float16),
+           dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16),
            supports_forward_ad=True,
            supports_fwgrad_bwgrad=True,
            # See https://github.com/pytorch/pytorch/pull/78358
@@ -561,8 +561,7 @@ op_db: List[OpInfo] = [
     ),
     OpInfo(
         "masked.cumsum",
-        dtypes=all_types_and_complex_and(torch.bfloat16),
-        dtypesIfCUDA=all_types_and_complex_and(torch.float16, torch.bfloat16),
+        dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16),
         method_variant=None,
         # Runs very slowly on slow gradcheck - alternatively reduce input sizes
         gradcheck_fast_mode=True,