[MPS] Add boilerplate sparse code support (#157238)
This PR makes the minimal changes needed to support sparse tensors on MPS. In follow-up PRs I'll gradually add individual operations so that we can resolve https://github.com/pytorch/pytorch/issues/129842, which is highly requested (I assume because Whisper uses sparse tensors).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157238
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 771be85704
Commit: a1282b1823
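
For context, here is a minimal sketch of what this boilerplate is meant to enable, mirroring the TestSparseMPS tests added below. It assumes a build that contains this change running on a machine with an MPS device; anything beyond the few ops wired up here may still be unimplemented.

import torch

# Build a small sparse COO tensor directly on the MPS device.
indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
values = torch.tensor([1.0, 2.0], device="mps")
t = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")

print(t.device.type)  # "mps"
print(t.layout)       # torch.sparse_coo
print(t._nnz())       # 2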
@@ -7285,26 +7285,26 @@
 - func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_sparse
+    SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_sparse
   autogen: _sparse_coo_tensor_with_dims.out
 
 - func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta, Meta: new_with_dims_and_tensor_sparse_symint
+    SparseCPU, SparseCUDA, SparseMeta, SparseMPS, Meta: new_with_dims_and_tensor_sparse_symint
   autogen: _sparse_coo_tensor_with_dims_and_tensors.out
 
 - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_
   autogen: sparse_resize, sparse_resize.out
 
 - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
   use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_resize_and_clear_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_resize_and_clear_
   autogen: sparse_resize_and_clear, sparse_resize_and_clear.out
 
 - func: sparse_mask(Tensor self, Tensor mask) -> Tensor

@@ -7340,8 +7340,8 @@
 - func: sparse_dim(Tensor self) -> int
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sparse_dim_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_dim_sparse_csr
     CompositeExplicitAutograd: sparse_dim_default
   device_check: NoCheck
   device_guard: False

@@ -7374,8 +7374,8 @@
 - func: _nnz(Tensor self) -> int
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _nnz_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: _nnz_sparse_csr
   device_check: NoCheck
   device_guard: False

@@ -7396,7 +7396,7 @@
 - func: is_coalesced(Tensor self) -> bool
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: is_coalesced_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: is_coalesced_sparse
     CompositeExplicitAutograd: is_coalesced_default
   device_check: NoCheck
   device_guard: False

@@ -7404,14 +7404,14 @@
 - func: _indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _indices_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _indices_sparse
   device_check: NoCheck
   device_guard: False
 
 - func: _values(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _values_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _values_sparse
   device_check: NoCheck
   device_guard: False

@@ -7421,7 +7421,7 @@
 - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: _coalesced_sparse_
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: _coalesced_sparse_
   device_check: NoCheck
   device_guard: False
   autogen: _coalesced, _coalesced.out

@@ -7510,9 +7510,9 @@
 - func: _to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor
   variants: method
   dispatch:
-    CPU, CUDA: dense_to_sparse
-    SparseCPU, SparseCUDA: sparse_coo_to_sparse
-    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
+    CPU, CUDA, MPS: dense_to_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta, SparseCsrMPS: sparse_compressed_to_sparse
   autogen: _to_sparse.sparse_dim_out
 
 - func: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor

@@ -7522,8 +7522,8 @@
 - func: _to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
   variants: method
   dispatch:
-    CPU, CUDA: dense_to_sparse
-    SparseCPU, SparseCUDA: sparse_coo_to_sparse
+    CPU, CUDA, MPS: dense_to_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sparse_coo_to_sparse
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_sparse
   autogen: _to_sparse.out
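
Taken together, the dispatch entries above cover sparse COO construction plus the basic metadata, conversion, and resize entry points for SparseMPS. A rough usage sketch under the same assumptions as before (ops outside this list may still be unimplemented on MPS):

import torch

dense = torch.tensor([[1.0, 0.0, 2.0],
                      [0.0, 3.0, 0.0]], device="mps")
s = dense.to_sparse()                         # routed via "CPU, CUDA, MPS: dense_to_sparse"

print(s.sparse_dim(), s._nnz())               # 2 3
print(s._indices().shape, s._values().shape)  # metadata accessors now have SparseMPS kernels
print(s.is_coalesced())

# In-place resize goes through the new SparseMPS sparse_resize_ entry.
s.sparse_resize_((4, 5), sparse_dim=2, dense_dim=0)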
@@ -38,6 +38,8 @@ enum class Backend {
   SparseCUDA,
   SparseCsrCPU,
   SparseCsrCUDA,
+  SparseCsrMPS,
+  SparseMPS,
   SparseHIP,
   SparseVE,
   SparseXPU,

@@ -94,6 +96,10 @@ inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::SparseCPU;
   } else if (t == DispatchKey::SparseCUDA) {
     return Backend::SparseCUDA;
+  } else if (t == DispatchKey::SparseMPS) {
+    return Backend::SparseMPS;
+  } else if (t == DispatchKey::SparseCsrMPS) {
+    return Backend::SparseCsrMPS;
   } else if (t == DispatchKey::SparseHIP) {
     return Backend::SparseHIP;
   } else if (t == DispatchKey::SparseVE) {

@@ -172,6 +178,10 @@ inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::SparseCPU;
     case Backend::SparseCUDA:
       return DispatchKey::SparseCUDA;
+    case Backend::SparseMPS:
+      return DispatchKey::SparseMPS;
+    case Backend::SparseCsrMPS:
+      return DispatchKey::SparseCsrMPS;
     case Backend::SparseHIP:
       return DispatchKey::SparseHIP;
     case Backend::SparseVE:

@@ -227,6 +237,8 @@ inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::CPU;
     case Backend::CUDA:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
     case Backend::QuantizedCUDA:
     case Backend::SparseCsrCUDA:
       return DeviceType::CUDA;

@@ -309,6 +321,10 @@ inline const char* toString(Backend b) {
      return "SparseCPU";
    case Backend::SparseCUDA:
      return "SparseCUDA";
+   case Backend::SparseMPS:
+     return "SparseMPS";
+   case Backend::SparseCsrMPS:
+     return "SparseCsrMPS";
    case Backend::SparseHIP:
      return "SparseHIP";
    case Backend::SparseVE:

@@ -361,6 +377,7 @@ inline bool isSparse(Backend b) {
     case Backend::SparseXPU:
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparsePrivateUse1:
@@ -354,6 +354,8 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"SparseCPU", c10::DispatchKey::SparseCPU},
       {"SparseCUDA", c10::DispatchKey::SparseCUDA},
+      {"SparseMPS", c10::DispatchKey::SparseMPS},
+      {"SparseCsrMPS", c10::DispatchKey::SparseCsrMPS},
       {"SparseHIP", c10::DispatchKey::SparseHIP},
       {"SparseXPU", c10::DispatchKey::SparseXPU},
       {"SparseVE", c10::DispatchKey::SparseVE},
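
parseDispatchKey maps the dispatch-key strings used by TORCH_LIBRARY_IMPL and torch.library onto c10::DispatchKey values, so with these entries a registration can target "SparseMPS" by name. A hypothetical sketch follows: the "demo" namespace and the nnz_plus_one op are made up purely for illustration and are not part of this PR.

import torch

# Define a toy custom op and register a kernel for the SparseMPS dispatch key,
# which parseDispatchKey now recognizes by name.
lib = torch.library.Library("demo", "DEF")
lib.define("nnz_plus_one(Tensor self) -> int")

def nnz_plus_one_sparse_mps(self):
    # Relies on _nnz(), which this PR wires up for SparseMPS.
    return self._nnz() + 1

lib.impl("nnz_plus_one", nnz_plus_one_sparse_mps, "SparseMPS")

t = torch.sparse_coo_tensor(
    torch.tensor([[0, 1], [2, 0]]), torch.tensor([1.0, 2.0]), (2, 3), device="mps")
print(torch.ops.demo.nnz_plus_one(t))  # expected: 3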
@@ -32,6 +32,8 @@ inline Layout layout_from_backend(Backend backend) {
   switch (backend) {
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparseXPU:

@@ -46,7 +48,7 @@ inline Layout layout_from_backend(Backend backend) {
     case Backend::SparseCsrXPU:
       TORCH_CHECK(
           false,
-          "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU) to a unique layout.");
+          "Cannot map Backend SparseCsr(CPU|CUDA|HIP|VE|XPU|MPS) to a unique layout.");
     default:
       return Layout::Strided;
   }
@@ -12547,6 +12547,58 @@ class TestMetalLibrary(TestCaseMPS):
                         f"Capture file {capture_dirname} contains only metadata, i.e. {capture_listdir}")
 
 
+class TestSparseMPS(TestCaseMPS):
+    def _get_basic_sparse_coo(self, device="mps"):
+        indices = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device=device)
+        values = torch.tensor([1, 2], dtype=torch.float32, device=device)
+        size = (2, 3)
+        return torch.sparse_coo_tensor(indices, values, size, device=device)
+
+    def test_sparse_coo_tensor_with_dims(self):
+        indices = torch.zeros((2, 0), dtype=torch.int64, device="mps")
+        values = torch.tensor([], dtype=torch.float32, device="mps")
+        size = (2, 3)
+        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        self.assertEqual(t.device.type, "mps")
+        self.assertEqual(t.layout, torch.sparse_coo)
+
+    def test_sparse_coo_tensor_with_dims_and_tensors(self):
+        indices = torch.tensor([[0, 1], [2, 0]], device="mps")
+        values = torch.tensor([1., 2.], device="mps")
+        size = (2, 3)
+        t = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        self.assertEqual(t.device.type, "mps")
+        self.assertEqual(t.layout, torch.sparse_coo)
+        self.assertEqual(t._indices().cpu(), indices.cpu())
+        self.assertEqual(t._values().cpu(), values.cpu())
+
+    def test_nnz(self):
+        t = self._get_basic_sparse_coo()
+        self.assertEqual(t._nnz(), 2)
+
+    def test_sparse_dim(self):
+        t = self._get_basic_sparse_coo()
+        self.assertEqual(t.sparse_dim(), 2)
+
+    def test_to_sparse(self):
+        t = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="mps")
+        x = t.to_sparse()
+        t_cpu = torch.tensor([[[1., 0], [2., 3.]], [[4., 0], [5., 6.]]], device="cpu")
+        x_cpu = t_cpu.to_sparse()
+        self.assertEqual(x.cpu(), x_cpu)
+
+    def test_resize(self):
+        indices = torch.tensor([[0, 1], [2, 0]])
+        values = torch.tensor([3.0, 4.0])
+        size = torch.Size([2, 3])
+        sparse = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu")
+        sparse = sparse.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
+        sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
+        self.assertEqual(sparse, sparse_cpu)
+
+
 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
@@ -23,6 +23,7 @@
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int4pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, int64_t qGroupSize, AtenTensorHandle qScaleAndZeros, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__weight_int8pack_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scales, AtenTensorHandle* ret0);
@@ -556,6 +556,7 @@ void check_base_legacy_new(
           c10::DispatchKey::SparseCUDA,
           c10::DispatchKey::SparseHIP,
           c10::DispatchKey::SparseXPU,
+          c10::DispatchKey::SparseMPS,
           c10::DispatchKey::SparsePrivateUse1,
       });
   TORCH_CHECK(
@@ -39,6 +39,8 @@ const char* backend_to_string(const at::Backend& backend) {
       return "torch.cuda.sparse";
     case at::Backend::SparseXPU:
       return "torch.xpu.sparse";
+    case at::Backend::SparseMPS:
+      return "torch.mps.sparse";
     case at::Backend::QuantizedCPU:
       return "torch.quantized";
     case at::Backend::HPU:
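
This mapping feeds the legacy type-name machinery, so a sparse MPS tensor should now report a "torch.mps.sparse.*" type name instead of falling through to an error. A quick check; the exact output string is an assumption rather than something verified here:

import torch

t = torch.sparse_coo_tensor(
    torch.tensor([[0], [0]]), torch.tensor([1.0]), (1, 1), device="mps")
print(t.type())  # expected: something like "torch.mps.sparse.FloatTensor"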
@@ -288,6 +288,8 @@ dispatch_keys = [
     DispatchKey.SparseCsrXPU,
     DispatchKey.SparseCUDA,
     DispatchKey.SparseCsrCUDA,
+    DispatchKey.SparseMPS,
+    DispatchKey.SparseCsrMPS,
     DispatchKey.QuantizedCPU,
     DispatchKey.QuantizedCUDA,
     DispatchKey.CompositeImplicitAutograd,