Add Half support for sparse.mm reduce (#133672)

This PR adds Half (float16) support for sparse.mm reduce in the CPU backend.
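
With this change, the reduce variant of torch.sparse.mm also dispatches for float16 inputs on CPU. A minimal usage sketch (shapes and values below are illustrative, not taken from this PR):

```python
import torch

# Illustrative shapes/values only; exercises the new CPU Half path.
a = torch.randn(4, 6, dtype=torch.float16).relu()   # dense matrix with some zeros
b = torch.randn(6, 3, dtype=torch.float16)

csr = a.to_sparse_csr()
out = torch.sparse.mm(csr, b, "sum")  # reduce variant, as in test_sparse_mm_reduce below
print(out.dtype, out.shape)           # torch.float16 torch.Size([4, 3])
```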

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672
Approved by: https://github.com/Skylion007
Author: Jiang, Yanbing
Date: 2024-08-17 15:20:39 +00:00
Committed by: PyTorch MergeBot
Parent: 1c6fbae579
Commit: 215b14530a
5 changed files with 22 additions and 13 deletions

@@ -434,7 +434,7 @@ void spmm_reduce_kernel(
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
@@ -452,7 +452,7 @@ void spmm_reduce_arg_kernel(
const Tensor& values,
const Tensor& other,
ReductionType reduce_op) {
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
@@ -471,7 +471,7 @@ void spmm_reduce_backward_input_kernel(
const Tensor& row_indices,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
@@ -489,7 +489,7 @@ void spmm_reduce_backward_input_arg_kernel(
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
grad_self, grad_out, col_indices, other, arg_out);
@@ -502,7 +502,7 @@ void spmm_reduce_normalize_values_kernel(
const Tensor& values,
const Tensor& crow_indices,
const Tensor& row_indices) {
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
normalized_values, values, crow_indices, row_indices);
@@ -545,7 +545,7 @@ void spmm_reduce_backward_other_arg_kernel(
const Tensor& arg_out,
ReductionType reduce_op) {
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
-AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
+AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
grad_other, grad_out, col_indices, values, arg_out);

@@ -53,7 +53,7 @@ inline bool data_index_step(T& x, const T& X, Args&&... args) {
return false;
}
-// Helper struct for bfloat16 vectorization
+// Helper struct for bfloat16/float16 vectorization
// Useful when you need float as immediate dtype or accumulate dtype
using namespace vec;
struct Vec2 {
@@ -64,6 +64,10 @@ struct Vec2 {
auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
return {v0, v1};
}
+static Vec2 loadu(const Half* ptr) {
+auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
+return {v0, v1};
+}
static Vec2 loadu(const float* ptr) {
return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
}
@@ -71,6 +75,10 @@ struct Vec2 {
Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
val.store(ptr);
}
+void store(Half* ptr) const {
+Vectorized<Half> val = convert_float_half(val0, val1);
+val.store(ptr);
+}
void store(float* ptr) const {
val0.store(ptr);
val1.store(ptr + Vectorized<float>::size());
@@ -85,6 +93,7 @@ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0,
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
+template <> struct VectorizedType<Half> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
// Helper for mixed data type parameter Vec::load

@@ -226,7 +226,7 @@ inductor_expected_failures_single_sample["cpu"] = {
("normal", "in_place"): {f16, f32, f64},
("normal", "number_mean"): {f16, f32, f64},
"normal": {f16, f32, f64},
("sparse.mm", "reduce"): {f32, f64},
("sparse.mm", "reduce"): {f32, f64, f16},
"sparse.sampled_addmm": {f32, f64},
"to_sparse": {
f32,

@@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase):
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
@onlyCPU
-@dtypes(torch.float32, torch.float64, torch.bfloat16)
+@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
@precisionOverride({torch.bfloat16: 0.01})
def test_sparse_mm_reduce_sum(self, device, dtype):
def run_test(m, n, k, nnz, train):
@@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase):
@skipIfTorchDynamo()
@onlyCPU
-@dtypes(torch.float32, torch.float64, torch.bfloat16)
-@precisionOverride({torch.bfloat16: 0.01})
+@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
+@precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
def test_sparse_mm_reduce(self, device, dtype):
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
@@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase):
out = torch.sparse.mm(csr, mat, reduce_type)
self.assertEqual(out, ref_out)
-if train and dtype is not torch.bfloat16:
+if train and dtype not in (torch.bfloat16, torch.float16):
ref_out.sum().backward()
out.sum().backward()
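
For reference, a low-precision check in the spirit of the precisionOverride above might look like the following sketch (shapes and tolerances are illustrative, not taken from the test):

```python
import torch

# Compare the new float16 CPU path against a float32 reference with a
# relaxed tolerance, mirroring the 0.01 override used for reduced precision.
a = torch.randn(8, 8).relu()
b = torch.randn(8, 4)

ref = torch.sparse.mm(a.to_sparse_csr(), b, "sum")                 # float32 reference
out = torch.sparse.mm(a.half().to_sparse_csr(), b.half(), "sum")   # Half path
torch.testing.assert_close(out.float(), ref, atol=1e-2, rtol=1e-2)
```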

@@ -13576,7 +13576,7 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
)),
OpInfo('sparse.mm',
-dtypes=floating_types_and(torch.bfloat16),
+dtypes=floating_types_and(torch.bfloat16, torch.float16),
variant_test_name='reduce',
supports_autograd=True,
supports_out=False,