mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Half for sparse.mm reduce (#133672)
This PR is to add Half support for sparse.mm reduce in CPU backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133672 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
1c6fbae579
commit
215b14530a
@ -2575,7 +2575,7 @@ class TestSparseCSR(TestCase):
|
||||
torch.sparse.sampled_addmm(a_sparse, a, a_sparse)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
|
||||
@precisionOverride({torch.bfloat16: 0.01})
|
||||
def test_sparse_mm_reduce_sum(self, device, dtype):
|
||||
def run_test(m, n, k, nnz, train):
|
||||
@ -2613,8 +2613,8 @@ class TestSparseCSR(TestCase):
|
||||
|
||||
@skipIfTorchDynamo()
|
||||
@onlyCPU
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16)
|
||||
@precisionOverride({torch.bfloat16: 0.01})
|
||||
@dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
|
||||
@precisionOverride({torch.bfloat16: 0.01, torch.float16: 0.01})
|
||||
def test_sparse_mm_reduce(self, device, dtype):
|
||||
def run_test(m, n, k, nnz, reduce_type, index_dtype, train):
|
||||
csr = self.genSparseCSRTensor((m, n), nnz, dtype=dtype, device=device, index_dtype=index_dtype)
|
||||
@ -2649,7 +2649,7 @@ class TestSparseCSR(TestCase):
|
||||
out = torch.sparse.mm(csr, mat, reduce_type)
|
||||
self.assertEqual(out, ref_out)
|
||||
|
||||
if train and dtype is not torch.bfloat16:
|
||||
if train and dtype not in (torch.bfloat16, torch.float16):
|
||||
ref_out.sum().backward()
|
||||
out.sum().backward()
|
||||
|
||||
|
Reference in New Issue
Block a user