mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Half for sparse.mm reduce (#133672)
This PR is to add Half support for sparse.mm reduce in CPU backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c6fbae579
commit
215b14530a
@ -434,7 +434,7 @@ void spmm_reduce_kernel(
|
||||
const Tensor& values,
|
||||
const Tensor& other,
|
||||
ReductionType reduce_op) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
|
||||
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
|
||||
spmm_reduce_kernel_impl<scalar_t, index_t, reduce>(
|
||||
@ -452,7 +452,7 @@ void spmm_reduce_arg_kernel(
|
||||
const Tensor& values,
|
||||
const Tensor& other,
|
||||
ReductionType reduce_op) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_indices", [&]() {
|
||||
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
|
||||
spmm_reduce_arg_kernel_impl<scalar_t, index_t, reduce>(
|
||||
@ -471,7 +471,7 @@ void spmm_reduce_backward_input_kernel(
|
||||
const Tensor& row_indices,
|
||||
ReductionType reduce_op) {
|
||||
TORCH_CHECK(reduce_op == ReductionType::SUM || reduce_op == ReductionType::MEAN);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_indices", [&]() {
|
||||
AT_DISPATCH_REDUCTION_TYPES(reduce_op, [&]() {
|
||||
spmm_reduce_backward_input_kernel_impl<scalar_t, index_t, reduce>(
|
||||
@ -489,7 +489,7 @@ void spmm_reduce_backward_input_arg_kernel(
|
||||
const Tensor& arg_out,
|
||||
ReductionType reduce_op) {
|
||||
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, other.scalar_type(), "spmm_reduce_backward_input_arg_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_input_arg_indices", [&]() {
|
||||
spmm_reduce_backward_input_arg_kernel_impl<scalar_t, index_t>(
|
||||
grad_self, grad_out, col_indices, other, arg_out);
|
||||
@ -502,7 +502,7 @@ void spmm_reduce_normalize_values_kernel(
|
||||
const Tensor& values,
|
||||
const Tensor& crow_indices,
|
||||
const Tensor& row_indices) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_normalize_values_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "spmm_reduce_normalize_values_indices", [&]() {
|
||||
spmm_reduce_normalize_values_kernel_impl<scalar_t, index_t>(
|
||||
normalized_values, values, crow_indices, row_indices);
|
||||
@ -545,7 +545,7 @@ void spmm_reduce_backward_other_arg_kernel(
|
||||
const Tensor& arg_out,
|
||||
ReductionType reduce_op) {
|
||||
TORCH_CHECK(reduce_op == ReductionType::MAX || reduce_op == ReductionType::MIN);
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, values.scalar_type(), "spmm_reduce_backward_other_arg_kernel", [&]() {
|
||||
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "spmm_reduce_backward_other_arg_indices", [&]() {
|
||||
spmm_reduce_backward_other_arg_kernel_impl<scalar_t, index_t>(
|
||||
grad_other, grad_out, col_indices, values, arg_out);
|
||||
|
@ -53,7 +53,7 @@ inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Helper struct for bfloat16 vectorization
|
||||
// Helper struct for bfloat16/float16 vectorization
|
||||
// Useful when you need float as immediate dtype or accumulate dtype
|
||||
using namespace vec;
|
||||
struct Vec2 {
|
||||
@ -64,6 +64,10 @@ struct Vec2 {
|
||||
auto [v0, v1] = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
|
||||
return {v0, v1};
|
||||
}
|
||||
static Vec2 loadu(const Half* ptr) {
|
||||
auto [v0, v1] = convert_half_float(Vectorized<Half>::loadu(ptr));
|
||||
return {v0, v1};
|
||||
}
|
||||
static Vec2 loadu(const float* ptr) {
|
||||
return {Vectorized<float>::loadu(ptr), Vectorized<float>::loadu(ptr + Vectorized<float>::size())};
|
||||
}
|
||||
@ -71,6 +75,10 @@ struct Vec2 {
|
||||
Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
|
||||
val.store(ptr);
|
||||
}
|
||||
void store(Half* ptr) const {
|
||||
Vectorized<Half> val = convert_float_half(val0, val1);
|
||||
val.store(ptr);
|
||||
}
|
||||
void store(float* ptr) const {
|
||||
val0.store(ptr);
|
||||
val1.store(ptr + Vectorized<float>::size());
|
||||
@ -85,6 +93,7 @@ inline Vec2 minimum(const Vec2& a, const Vec2& b) { return {vec::minimum(a.val0,
|
||||
|
||||
template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
|
||||
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
|
||||
template <> struct VectorizedType<Half> { using type = Vec2; };
|
||||
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
|
||||
|
||||
// Helper for mixed data type parameter Vec::load
|
||||
|
@ -226,7 +226,7 @@ inductor_expected_failures_single_sample["cpu"] = {
|
||||
("normal", "in_place"): {f16, f32, f64},
|
||||
("normal", "number_mean"): {f16, f32, f64},
|
||||
"normal": {f16, f32, f64},
|
||||
("sparse.mm", "reduce"): {f32, f64},
|
||||
("sparse.mm", "reduce"): {f32, f64, f16},
|
||||
"sparse.sampled_addmm": {f32, f64},
|
||||
"to_sparse": {
|
||||
f32,
|
||||
|
@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase):
|
||||
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
|
||||
@precisionOverride({torch.bfloat16: 0.01})
|
||||
def test_sparse_mm_reduce_sum(self, device, dtype):
|
||||
def run_test(m, n, k, nnz, train):
|
||||
@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase):
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
@precisionOverride({torch.bfloat16: 0.01})
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
|
||||
@precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
|
||||
def test_sparse_mm_reduce(self, device, dtype):
|
||||
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
|
||||
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
|
||||
@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase):
|
||||
out = torch.sparse.mm(csr, mat, reduce_type)
|
||||
self.assertEqual(out, ref_out)
|
||||
|
||||
if train and dtype is not torch.bfloat16:
|
||||
if train and dtype not in (torch.bfloat16, torch.float16):
|
||||
ref_out.sum().backward()
|
||||
out.sum().backward()
|
||||
|
||||
|
@ -13576,7 +13576,7 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
|
||||
)),
|
||||
OpInfo('sparse.mm',
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypes=floating_types_and(torch.bfloat16, torch.float16),
|
||||
variant_test_name='reduce',
|
||||
supports_autograd=True,
|
||||
supports_out=False,
|
||||
|
Reference in New Issue
Block a user