[MPS] mps sparse mul op implementation (#162349)
Implements the sparse mul op for MPS and also enables several other sparse operations:

1. copy_
2. div
3. sum
4. floor
5. power
6. sub
7. floor_divide

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162349
Approved by: https://github.com/pearu, https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: be3b8d2ec9
Commit: 3ea6868049
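For context, a minimal usage sketch of the elementwise sparse-sparse multiply this commit implements. The snippet is mine, not from the PR, and assumes a PyTorch build where the MPS backend is available (it falls back to CPU otherwise):

import torch

# Sketch only: pick MPS when available, assuming a macOS build with MPS support.
device = "mps" if torch.backends.mps.is_available() else "cpu"

indices = torch.tensor([[0, 1, 1],
                        [2, 0, 2]])
a = torch.sparse_coo_tensor(indices, torch.tensor([3.0, 4.0, 5.0]), (2, 3), device=device)
b = torch.sparse_coo_tensor(indices, torch.tensor([6.0, 7.0, 8.0]), (2, 3), device=device)

# Elementwise sparse * sparse multiplication, the op implemented here for MPS.
c = a * b
print(c.to_dense())

The test changes below follow from this: tests that previously carried an MPS expected-failure marker now run, with MPS-specific dtypes substituted where needed.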
@@ -1108,8 +1108,8 @@ class TestSparse(TestSparseBase):
         test_shape(2, 20, [3, 17, 19, 5])
         test_shape(2, 20, [3, 17, 19, 0])
 
-    @expectedFailureMPS
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfMPS(torch.float32, torch.complex64)
     def test_add_sub_nnz(self, device, dtype):
         # nnz should not grow unbounded (gh-34964)
         x = torch.randn(10, dtype=dtype, device=device).to_sparse()
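The pattern repeated throughout these hunks: @expectedFailureMPS is dropped now that the op works on MPS, and @dtypesIfMPS supplies MPS-supported dtypes (the MPS backend has no float64, so double-precision tests run in float32 there). As a rough illustration of the mechanism, a toy dtype-override decorator in the spirit of the real torch.testing harness (this is not the actual harness) might look like:

import torch

def dtypes(*ds):
    # Default dtype list for a test.
    def wrap(fn):
        fn._dtypes = ds
        return fn
    return wrap

def dtypes_if_mps(*ds):
    # MPS-specific override, consulted only when running on "mps".
    def wrap(fn):
        fn._mps_dtypes = ds
        return fn
    return wrap

def run_for_device(fn, device):
    # Prefer the MPS override when present and running on MPS.
    override = getattr(fn, "_mps_dtypes", None) if device == "mps" else None
    for dtype in override or fn._dtypes:
        fn(device, dtype)

@dtypes_if_mps(torch.float32)
@dtypes(torch.double)
def test_mul(device, dtype):
    print(device, dtype)

run_for_device(test_mul, "cpu")  # runs with torch.float64
run_for_device(test_mul, "mps")  # runs with torch.float32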
@@ -1687,8 +1687,8 @@ class TestSparse(TestSparseBase):
         test_shape(7, 8, 9, 20, True)
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
     @gradcheck_semantics()
     def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
@@ -1868,8 +1868,8 @@ class TestSparse(TestSparseBase):
         x.norm(**kwargs)
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
     def test_sparse_sum(self, device, dtype, coalesced):
 
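test_sparse_sum now runs on MPS in float32 rather than being an expected failure. A brief sketch of the reduction it exercises, again mine rather than the PR's, using the public torch.sparse.sum API:

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

i = torch.tensor([[0, 0, 1],
                  [1, 2, 0]])
v = torch.tensor([1.0, 2.0, 3.0])
s = torch.sparse_coo_tensor(i, v, (2, 3), device=device)

print(torch.sparse.sum(s))         # full reduction, returns a dense scalar tensor
print(torch.sparse.sum(s, dim=1))  # row-wise reduction, result stays sparse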
@@ -1933,7 +1933,6 @@ class TestSparse(TestSparseBase):
         S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
         run_tests(S.requires_grad_(True), test_dim)
 
-    @expectedFailureMPS
     def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
         shape = shape_i + (shape_v)
         x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
@@ -2011,6 +2010,7 @@ class TestSparse(TestSparseBase):
 
     @coalescedonoff
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     def test_basic_ops(self, device, dtype, coalesced):
 
         def _test_basic_ops():