min-jean-cho
2023-01-11 23:23:42 +00:00
committed by PyTorch MergeBot
parent f40777e4ad
commit af242eedfb
4 changed files with 12 additions and 3 deletions

View File

@ -347,6 +347,7 @@ class TestSparse(TestSparseBase):
lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1])))
@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense(self, device, dtype):
def test_tensor(x, res):
x.to_dense() # Tests triple to_dense for memory corruption
@ -479,6 +480,7 @@ class TestSparse(TestSparseBase):
self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_to_dense_hybrid(self, device, dtype):
def test_tensor(x, res):
x.to_dense() # Tests double to_dense for memory corruption
@ -832,6 +834,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double, torch.cdouble)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_permute(self, device, dtype, coalesced):
# trivial checks
s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
@ -1488,6 +1491,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_sparse_mm(self, device, dtype, coalesced):
def test_shape(d1, d2, d3, nnz, transposed):
if transposed:
@ -1509,6 +1513,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
def test_sparse_mul(self, device, dtype, coalesced):
# https://github.com/pytorch/pytorch/issues/79914
a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True)