mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[MPS] enable cat op for sparse (#162007)"
This reverts commit 2c03f0acc53ed13fe8ebfe809129f25996e009a0. Reverted https://github.com/pytorch/pytorch/pull/162007 on behalf of https://github.com/jeanschmidt due to Breaks internal builds see [D81588372](https://www.internalfb.com/diff/D81588372), @malfet may you help the author? ([comment](https://github.com/pytorch/pytorch/pull/162007#issuecomment-3255357336))
This commit is contained in:
@ -1412,7 +1412,7 @@
|
||||
- func: cat(Tensor[] tensors, int dim=0) -> Tensor
|
||||
structured_delegate: cat.out
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA, SparseMPS: cat_sparse
|
||||
SparseCPU, SparseCUDA: cat_sparse
|
||||
QuantizedCPU: cat_quantized_cpu
|
||||
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: cat_nested
|
||||
tags: core
|
||||
|
||||
@ -1121,9 +1121,9 @@ class TestSparse(TestSparseBase):
|
||||
x.sub_(2 * x)
|
||||
self.assertLessEqual(x._nnz(), 10)
|
||||
|
||||
@expectedFailureMPS
|
||||
@coalescedonoff
|
||||
@dtypes(torch.double, torch.cdouble)
|
||||
@dtypesIfMPS(torch.float32, torch.complex64)
|
||||
def test_cat(self, device, dtype, coalesced):
|
||||
# shapes: list of tuples (sparse_dims, nnz, sizes)
|
||||
def test_shapes(shapes, dim, fail_message=None):
|
||||
|
||||
Reference in New Issue
Block a user