Add Half for sparse.mm reduce (#133672)

This PR adds Half (torch.float16) support for sparse.mm reduce on the CPU backend.
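
For context, a minimal usage sketch of what this enables; the shapes, tolerance, and the assert_close check are illustrative, not taken from the PR:

import torch

# A sparse CSR matrix in half precision on the CPU backend.
dense = torch.randn(4, 6, dtype=torch.float16).relu()  # relu() zeroes some entries
csr = dense.to_sparse_csr()
mat = torch.randn(6, 3, dtype=torch.float16)

# With this PR, the reduce overload of torch.sparse.mm accepts float16 on CPU
# (reduce can be "sum", "mean", "amax", or "amin").
out = torch.sparse.mm(csr, mat, "sum")

# "sum" reduction matches a dense matmul; compare with a loose tolerance,
# in the spirit of the precisionOverride used by the tests below.
torch.testing.assert_close(out, dense @ mat, atol=1e-2, rtol=1e-2)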

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672
Approved by: https://github.com/Skylion007
Author:    Jiang, Yanbing
Date:      2024-08-17 15:20:39 +00:00
Committed: PyTorch MergeBot
Parent:    1c6fbae579
Commit:    215b14530a

5 changed files with 22 additions and 13 deletions


@@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase):
             torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
 
     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
     @precisionOverride({torch.bfloat16: 0.01})
     def test_sparse_mm_reduce_sum(self, device, dtype):
         def run_test(m, n, k, nnz, train):
@@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase):
 
     @skipIfTorchDynamo()
     @onlyCPU
-    @dtypes(torch.float32, torch.float64, torch.bfloat16)
-    @precisionOverride({torch.bfloat16: 0.01})
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
+    @precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
     def test_sparse_mm_reduce(self, device, dtype):
         def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
             csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
@@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase):
             out = torch.sparse.mm(csr, mat, reduce_type)
             self.assertEqual(out, ref_out)
 
-            if train and dtype is not torch.bfloat16:
+            if train and dtype not in (torch.bfloat16, torch.float16):
                 ref_out.sum().backward()
                 out.sum().backward()
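
As the last hunk shows, the backward comparison is skipped for the reduced-precision dtypes; in float32/float64 the train branch also checks gradients. A minimal sketch of that path (shapes illustrative):

import torch

dense = torch.randn(4, 6).relu()           # relu() zeroes some entries
csr = dense.to_sparse_csr()
mat = torch.randn(6, 3, requires_grad=True)

out = torch.sparse.mm(csr, mat, "mean")    # the reduce variant exercised by the test
out.sum().backward()
print(mat.grad.shape)                      # torch.Size([6, 3])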