[MPS] Add boilerplate sparse code support (#157238)

This PR makes minimal changes to support sparse tensors on MPS. In follow-up PRs I'll gradually add individual operations so we can eventually close
https://github.com/pytorch/pytorch/issues/129842
which is highly requested (I assume because Whisper uses sparse tensors).
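
As a rough sketch of what this boilerplate enables (assuming a macOS build with MPS available and this change applied), basic COO construction and introspection now work on the "mps" device:

import torch

# Minimal sketch; mirrors the new TestSparseMPS tests added in this PR.
indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
values = torch.tensor([1.0, 2.0], device="mps")
t = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")
print(t.layout)        # torch.sparse_coo
print(t._nnz())        # 2
print(t.sparse_dim())  # 2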

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157238
Approved by: https://github.com/malfet
Author: Isalia20
Date: 2025-06-30 01:53:45 +00:00
Committed by: PyTorch MergeBot
Parent: 771be85704
Commit: a1282b1823
9 changed files with 97 additions and 18 deletions

View File

@@ -7285,26 +7285,26 @@
- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
dispatch:
SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_sparse
autogen: _sparse_coo_tensor_with_dims.out
- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
dispatch:
SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint
SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_and_tensor_sparse_symint
autogen: _sparse_coo_tensor_with_dims_and_tensors.out
- func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: sparse_resize_
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_
autogen: sparse_resize, sparse_resize.out
- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_and_clear_
autogen: sparse_resize_and_clear, sparse_resize_and_clear.out
- func: sparse_mask(Tensor self, Tensor mask) -> Tensor
@@ -7340,8 +7340,8 @@
- func: sparse_dim(Tensor self) -> int
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_dim_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_dim_sparse_csr
CompositeExplicitAutograd: sparse_dim_default
device_check: NoCheck
device_guard: False
@@ -7374,8 +7374,8 @@
- func: _nnz(Tensor self) -> int
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _nnz_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: _nnz_sparse_csr
device_check: NoCheck
device_guard: False
@@ -7396,7 +7396,7 @@
- func: is_coalesced(Tensor self) -> bool
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: is_coalesced_sparse
CompositeExplicitAutograd: is_coalesced_default
device_check: NoCheck
device_guard: False
@@ -7404,14 +7404,14 @@
- func: _indices(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: _indices_sparse
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _indices_sparse
device_check: NoCheck
device_guard: False
- func: _values(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: _values_sparse
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _values_sparse
device_check: NoCheck
device_guard: False
@@ -7421,7 +7421,7 @@
- func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _coalesced_sparse_
device_check: NoCheck
device_guard: False
autogen: _coalesced, _coalesced.out
@@ -7510,9 +7510,9 @@
- func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse
SparseCPU, SparseCUDA: sparse_coo_to_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
CPU, CUDA, MPS: dense_to_sparse
SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, SparseCsrMPS: sparse_compressed_to_sparse
autogen: _to_sparse.sparse_dim_out
- func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
@@ -7522,8 +7522,8 @@
- func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
variants: method
dispatch:
CPU, CUDA: dense_to_sparse
SparseCPU, SparseCUDA: sparse_coo_to_sparse
CPU, CUDA, MPS: dense_to_sparse
SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
autogen: _to_sparse.out
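
These native_functions.yaml dispatch entries simply add the SparseMPS (and SparseCsrMPS) keys to the same device-generic sparse implementations already used for SparseCPU and SparseCUDA; no MPS-specific kernels are introduced here. A hedged sketch of the resulting Python-level behaviour, assuming an MPS-enabled build:

import torch

# Sketch only: exercises ops whose dispatch lists gained SparseMPS above.
dense = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 2.0]], device="mps")
sp = dense.to_sparse()              # _to_sparse: dense_to_sparse now covers MPS
print(sp._nnz(), sp.sparse_dim())   # _nnz_sparse, sparse_dim_sparse
print(sp.is_coalesced())            # is_coalesced_sparse
print(sp._indices(), sp._values())  # _indices_sparse, _values_sparse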

View File

@@ -38,6 +38,8 @@ enum class Backend {
SparseCUDA,
SparseCsrCPU,
SparseCsrCUDA,
SparseCsrMPS,
SparseMPS,
SparseHIP,
SparseVE,
SparseXPU,
@@ -94,6 +96,10 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::SparseCPU;
} else if (t == DispatchKey::SparseCUDA) {
return Backend::SparseCUDA;
} else if (t == DispatchKey::SparseMPS) {
return Backend::SparseMPS;
} else if (t == DispatchKey::SparseCsrMPS) {
return Backend::SparseCsrMPS;
} else if (t == DispatchKey::SparseHIP) {
return Backend::SparseHIP;
} else if (t == DispatchKey::SparseVE) {
@@ -172,6 +178,10 @@ inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::SparseCPU;
case Backend::SparseCUDA:
return DispatchKey::SparseCUDA;
case Backend::SparseMPS:
return DispatchKey::SparseMPS;
case Backend::SparseCsrMPS:
return DispatchKey::SparseCsrMPS;
case Backend::SparseHIP:
return DispatchKey::SparseHIP;
case Backend::SparseVE:
@@ -227,6 +237,8 @@ inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::CPU;
case Backend::CUDA:
case Backend::SparseCUDA:
case Backend::SparseMPS:
case Backend::SparseCsrMPS:
case Backend::QuantizedCUDA:
case Backend::SparseCsrCUDA:
return DeviceType::CUDA;
@@ -309,6 +321,10 @@ inline const char* toString(Backend b) {
return "SparseCPU";
case Backend::SparseCUDA:
return "SparseCUDA";
case Backend::SparseMPS:
return "SparseMPS";
case Backend::SparseCsrMPS:
return "SparseCsrMPS";
case Backend::SparseHIP:
return "SparseHIP";
case Backend::SparseVE:
@@ -361,6 +377,7 @@ inline bool isSparse(Backend b) {
case Backend::SparseXPU:
case Backend::SparseCPU:
case Backend::SparseCUDA:
case Backend::SparseMPS:
case Backend::SparseHIP:
case Backend::SparseVE:
case Backend::SparsePrivateUse1:

View File

@@ -354,6 +354,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"SparseCPU", c10::DispatchKey::SparseCPU},
{"SparseCUDA", c10::DispatchKey::SparseCUDA},
{"SparseMPS", c10::DispatchKey::SparseMPS},
{"SparseCsrMPS", c10::DispatchKey::SparseCsrMPS},
{"SparseHIP", c10::DispatchKey::SparseHIP},
{"SparseXPU", c10::DispatchKey::SparseXPU},
{"SparseVE", c10::DispatchKey::SparseVE},

View File

@@ -32,6 +32,8 @@ inline Layout layout_from_backend(Backend backend) {
switch (backend) {
case Backend::SparseCPU:
case Backend::SparseCUDA:
case Backend::SparseMPS:
case Backend::SparseCsrMPS:
case Backend::SparseHIP:
case Backend::SparseVE:
case Backend::SparseXPU:
@@ -46,7 +48,7 @@ inline Layout layout_from_backend(Backend backend) {
case Backend::SparseCsrXPU:
TORCH_CHECK(
false,
"Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
"Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout.");
default:
return Layout::Strided;
}
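
These hunks add the new MPS sparse backends to the sparse (COO) branch of layout_from_backend and extend the SparseCsr error message. The observable effect at the Python level is that a COO tensor on "mps" reports the same layout as on CPU; a small sketch (MPS build assumed), mirroring the checks in the new tests:

import torch

# Sketch: the COO layout is reported consistently for CPU and MPS sparse tensors.
indices = torch.tensor([[0, 1], [2, 0]])
values = torch.tensor([1.0, 2.0])
mps_t = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")
cpu_t = torch.sparse_coo_tensor(indices, values, (2, 3), device="cpu")
assert mps_t.layout == cpu_t.layout == torch.sparse_coo
assert mps_t.device.type == "mps"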

View File

@@ -12547,6 +12547,58 @@ class TestMetalLibrary(TestCaseMPS):
f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
class TestSparseMPS(TestCaseMPS):
    def _get_basic_sparse_coo(self, device="mps"):
        indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device=device)
        values = torch.tensor([1, 2], dtype=torch.float32, device=device)
        size = (2, 3)
        return torch.sparse_coo_tensor(indices, values, size, device=device)

    def test_sparse_coo_tensor_with_dims(self):
        indices = torch.zeros((2, 0), dtype=torch.int64, device="mps")
        values = torch.tensor([], dtype=torch.float32, device="mps")
        size = (2, 3)
        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
        self.assertEqual(t.device.type, "mps")
        self.assertEqual(t.layout, torch.sparse_coo)

    def test_sparse_coo_tensor_with_dims_and_tensors(self):
        indices = torch.tensor([[0, 1], [2, 0]], device="mps")
        values = torch.tensor([1., 2.], device="mps")
        size = (2, 3)
        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
        self.assertEqual(t.device.type, "mps")
        self.assertEqual(t.layout, torch.sparse_coo)
        self.assertEqual(t._indices().cpu(), indices.cpu())
        self.assertEqual(t._values().cpu(), values.cpu())

    def test_nnz(self):
        t = self._get_basic_sparse_coo()
        self.assertEqual(t._nnz(), 2)

    def test_sparse_dim(self):
        t = self._get_basic_sparse_coo()
        self.assertEqual(t.sparse_dim(), 2)

    def test_to_sparse(self):
        t = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="mps")
        x = t.to_sparse()
        # Build the same tensor on CPU and compare against its to_sparse() result.
        t_cpu = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="cpu")
        x_cpu = t_cpu.to_sparse()
        self.assertEqual(x.cpu(), x_cpu)

    def test_resize(self):
        indices = torch.tensor([[0, 1], [2, 0]])
        values = torch.tensor([3.0, 4.0])
        size = torch.Size([2, 3])
        sparse = torch.sparse_coo_tensor(indices, values, size, device="mps")
        sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu")
        sparse = sparse.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
        sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
        self.assertEqual(sparse, sparse_cpu)

# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342

View File

@@ -23,6 +23,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTe
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);

View File

@@ -556,6 +556,7 @@ void check_base_legacy_new(
c10::DispatchKey::SparseCUDA,
c10::DispatchKey::SparseHIP,
c10::DispatchKey::SparseXPU,
c10::DispatchKey::SparseMPS,
c10::DispatchKey::SparsePrivateUse1,
});
TORCH_CHECK(

View File

@@ -39,6 +39,8 @@ const char* backend_to_string(const at::Backend& backend) {
return "torch.cuda.sparse";
case at::Backend::SparseXPU:
return "torch.xpu.sparse";
case at::Backend::SparseMPS:
return "torch.mps.sparse";
case at::Backend::QuantizedCPU:
return "torch.quantized";
case at::Backend::HPU:

View File

@@ -288,6 +288,8 @@ dispatch_keys = [
DispatchKey.SparseCsrXPU,
DispatchKey.SparseCUDA,
DispatchKey.SparseCsrCUDA,
DispatchKey.SparseMPS,
DispatchKey.SparseCsrMPS,
DispatchKey.QuantizedCPU,
DispatchKey.QuantizedCUDA,
DispatchKey.CompositeImplicitAutograd,