Add Half for sparse.mm reduce (#133672)
This PR adds Half (float16) support for sparse.mm reduce in the CPU backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: 1c6fbae579
Commit: 215b14530a
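As a quick usage illustration (not part of the commit itself), the sketch below calls torch.sparse.mm with a reduce argument on a CSR tensor holding float16 values on CPU, which is what this change enables; the shapes, values, and tolerances are arbitrary choices for the example.

import torch

# Illustration only: sparse.mm reduce with Half inputs on CPU.
crow_indices = torch.tensor([0, 2, 3])
col_indices = torch.tensor([0, 2, 1])
values = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)
csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 3))
dense = torch.randn(3, 4, dtype=torch.float16)

out = torch.sparse.mm(csr, dense, "sum")   # reduce can also be "mean", "amax", "amin"

# Compare against a float32 reference with a loose fp16 tolerance.
ref = torch.sparse.mm(csr.to(torch.float32), dense.float(), "sum").to(torch.float16)
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)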
@@ -434,7 +434,7 @@ void spmm_reduce_kernel(
     const Tensor& values,
     const Tensor& other,
     ReductionType reduce_op) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
       AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
         spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
@@ -452,7 +452,7 @@ void spmm_reduce_arg_kernel(
     const Tensor& values,
     const Tensor& other,
     ReductionType reduce_op) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
       AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
         spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
@@ -471,7 +471,7 @@ void spmm_reduce_backward_input_kernel(
     const Tensor& row_indices,
     ReductionType reduce_op) {
   TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
       AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
         spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
@@ -489,7 +489,7 @@ void spmm_reduce_backward_input_arg_kernel(
     const Tensor& arg_out,
     ReductionType reduce_op) {
   TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
       spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
           grad_self, grad_out, col_indices, other, arg_out);
@@ -502,7 +502,7 @@ void spmm_reduce_normalize_values_kernel(
     const Tensor& values,
     const Tensor& crow_indices,
     const Tensor& row_indices) {
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
       spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
           normalized_values, values, crow_indices, row_indices);
@@ -545,7 +545,7 @@ void spmm_reduce_backward_other_arg_kernel(
     const Tensor& arg_out,
     ReductionType reduce_op) {
   TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
     AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
       spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
           grad_other, grad_out, col_indices, values, arg_out);
@@ -53,7 +53,7 @@ inline bool data_index_step(T& x, const T& X, Args&&... args) {
   return false;
 }
 
-// Helper struct for bfloat16 vectorization
+// Helper struct for bfloat16/float16 vectorization
 // Useful when you need float as immediate dtype or accumulate dtype
 using namespace vec;
 struct Vec2 {
@@ -64,6 +64,10 @@ struct Vec2 {
     auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
     return {v0, v1};
   }
+  static Vec2 loadu(const Half* ptr) {
+    auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
+    return {v0, v1};
+  }
   static Vec2 loadu(const float* ptr) {
     return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
   }
@@ -71,6 +75,10 @@ struct Vec2 {
     Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
     val.store(ptr);
   }
+  void store(Half* ptr) const {
+    Vectorized<Half> val = convert_float_half(val0, val1);
+    val.store(ptr);
+  }
   void store(float* ptr) const {
     val0.store(ptr);
     val1.store(ptr + Vectorized<float>::size());
@@ -85,6 +93,7 @@ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0,
 
 template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
 template <> struct VectorizedType<BFloat16> { using type = Vec2; };
+template <> struct VectorizedType<Half> { using type = Vec2; };
 template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
 
 // Helper for mixed data type parameter Vec::load
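The comment above notes that Vec2 is useful when float is needed as the intermediate or accumulate dtype. As an illustration of why that matters for Half (not part of the commit), the sketch below sums many fp16 values: a running fp16 sum stalls because of rounding, while accumulating in float32 and casting back once stays close to the true value.

import torch

# Illustration only: the effect of accumulating Half in float32,
# which is the role Vec2 / VecType<Half> plays in the vectorized kernels.
x = torch.full((10000,), 0.1, dtype=torch.float16)

acc_fp16 = torch.tensor(0.0, dtype=torch.float16)
for v in x:
    acc_fp16 = acc_fp16 + v              # rounds to fp16 at every step; the sum stalls

acc_fp32 = x.to(torch.float32).sum().to(torch.float16)   # accumulate in float, cast once

print(acc_fp16.item())   # stalls well below the true value (~1000)
print(acc_fp32.item())   # close to 1000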
@ -226,7 +226,7 @@ inductor_expected_failures_single_sample["cpu"] = {
|
|||||||
("normal", "in_place"): {f16, f32, f64},
|
("normal", "in_place"): {f16, f32, f64},
|
||||||
("normal", "number_mean"): {f16, f32, f64},
|
("normal", "number_mean"): {f16, f32, f64},
|
||||||
"normal": {f16, f32, f64},
|
"normal": {f16, f32, f64},
|
||||||
("sparse.mm", "reduce"): {f32, f64},
|
("sparse.mm", "reduce"): {f32, f64, f16},
|
||||||
"sparse.sampled_addmm": {f32, f64},
|
"sparse.sampled_addmm": {f32, f64},
|
||||||
"to_sparse": {
|
"to_sparse": {
|
||||||
f32,
|
f32,
|
||||||
|
@@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase):
             torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
 
     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
     @precisionOverride({torch.bfloat16: 0.01})
     def test_sparse_mm_reduce_sum(self, device, dtype):
         def run_test(m, n, k, nnz, train):
@@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase):
 
     @skipIfTorchDynamo()
     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
-    @precisionOverride({torch.bfloat16: 0.01})
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
+    @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
     def test_sparse_mm_reduce(self, device, dtype):
         def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
             csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
@@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase):
             out = torch.sparse.mm(csr, mat, reduce_type)
             self.assertEqual(out, ref_out)
 
-            if train and dtype is not torch.bfloat16:
+            if train and dtype not in (torch.bfloat16, torch.float16):
                 ref_out.sum().backward()
                 out.sum().backward()
 
@@ -13576,7 +13576,7 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
            )),
     OpInfo('sparse.mm',
-           dtypes=floating_types_and(torch.bfloat16),
+           dtypes=floating_types_and(torch.bfloat16, torch.float16),
            variant_test_name='reduce',
            supports_autograd=True,
            supports_out=False,