mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] zeros like, narrow and enable tests (#163011)
zeros like, narrow and enable tests for SparseMPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/163011 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
559e8d1c20
commit
6db37d7206
@ -4372,7 +4372,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CPU: narrow_copy_dense_cpu
|
||||
SparseCPU, SparseCUDA: narrow_copy_sparse
|
||||
SparseCPU, SparseCUDA, SparseMPS: narrow_copy_sparse
|
||||
CompositeExplicitAutogradNonFunctional: narrow_copy_dense_symint
|
||||
tags: view_copy
|
||||
|
||||
@ -6660,7 +6660,7 @@
|
||||
- func: zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: zeros_out
|
||||
SparseCPU, SparseCUDA, SparseMeta: zeros_sparse_out
|
||||
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zeros_sparse_out
|
||||
|
||||
- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
dispatch:
|
||||
|
@ -479,8 +479,8 @@ class TestSparse(TestSparseBase):
|
||||
"cannot set is_coalesced to true if indices correspond to uncoalesced COO tensor"):
|
||||
torch.autograd.gradcheck(func, (t._indices(), t._values().requires_grad_(True), shape, True))
|
||||
|
||||
@expectedFailureMPS
|
||||
@dtypes(*floating_and_complex_types_and(torch.float16, torch.bfloat16))
|
||||
@dtypesIfMPS(*all_mps_types())
|
||||
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
|
||||
@gradcheck_semantics()
|
||||
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
|
||||
@ -505,7 +505,8 @@ class TestSparse(TestSparseBase):
|
||||
x.requires_grad_(True)
|
||||
gradcheck(fn, (x,))
|
||||
|
||||
for value_type in [torch.double, torch.cdouble]:
|
||||
values_types = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64]
|
||||
for value_type in values_types:
|
||||
i = self.index_tensor([
|
||||
[0, 1, 2, 2],
|
||||
[0, 0, 0, 3],
|
||||
@ -859,8 +860,8 @@ class TestSparse(TestSparseBase):
|
||||
test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
|
||||
|
||||
@coalescedonoff
|
||||
@expectedFailureMPS
|
||||
@dtypes(torch.double, torch.cdouble, torch.bfloat16)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64, torch.bfloat16)
|
||||
@precisionOverride({torch.bfloat16: 2e-2})
|
||||
def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
|
||||
# This is for testing torch.copy_(SparseTensor, SparseTensor)
|
||||
@ -883,7 +884,7 @@ class TestSparse(TestSparseBase):
|
||||
x1.copy_(x2)
|
||||
self.assertEqual(x1_dtype, x1.dtype)
|
||||
|
||||
x2 = x2.to(torch.float64)
|
||||
x2 = x2.to(torch.float64) if device != "mps:0" else x2.to(torch.float32)
|
||||
x1_dtype = x1.dtype
|
||||
x1.copy_(x2)
|
||||
self.assertEqual(x1_dtype, x1.dtype)
|
||||
@ -2275,8 +2276,8 @@ class TestSparse(TestSparseBase):
|
||||
test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12])
|
||||
|
||||
@coalescedonoff
|
||||
@expectedFailureMPS
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
def test_zeros_like(self, device, dtype, coalesced):
|
||||
def _test_zeros_like(nnzs, template_shape_i, template_shape_v=None):
|
||||
template_shape_v = template_shape_v or []
|
||||
@ -2416,8 +2417,8 @@ class TestSparse(TestSparseBase):
|
||||
yield [dim, start, length]
|
||||
|
||||
@coalescedonoff
|
||||
@expectedFailureMPS
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
def test_narrow(self, device, dtype, coalesced):
|
||||
shape = [3, 3, 4, 2]
|
||||
input, _, _ = self._gen_sparse(4, 19, shape, dtype, device, coalesced)
|
||||
@ -3278,8 +3279,8 @@ class TestSparse(TestSparseBase):
|
||||
self.assertEqual(list(t.coalesce().values().size()), [1, 3])
|
||||
|
||||
@coalescedonoff
|
||||
@expectedFailureMPS
|
||||
@dtypes(torch.double)
|
||||
@dtypesIfMPS(torch.float32)
|
||||
def test_pickle(self, device, dtype, coalesced):
|
||||
import pickle
|
||||
|
||||
|
Reference in New Issue
Block a user