[MPS] sparse matmuls (#165232)

Implements matmuls for sparse tensors on MPS. With this commit, most of the core sparse operations should be implemented. Fixes:
- https://github.com/pytorch/pytorch/issues/156540
- https://github.com/pytorch/pytorch/issues/129842

Should be merged after:
https://github.com/pytorch/pytorch/pull/165102
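
A minimal usage sketch of the ops this enables on MPS (shapes and values below are arbitrary, for illustration only):
```python
import torch

# 2-D sparse COO x dense: mm / addmm now dispatch to the MPS kernels
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
A = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps").coalesce()
B = torch.randn(3, 4, device="mps")
C = torch.randn(2, 4, device="mps")

y_mm = torch.sparse.mm(A, B)                          # A @ B
y_addmm = torch.addmm(C, A, B, beta=0.5, alpha=2.0)   # 0.5*C + 2.0*(A @ B)

# 3-D sparse COO x dense: batched matmul
batched_idx = torch.tensor([[0, 1], [0, 1], [1, 0]])  # rows are (batch, row, col) indices per nnz
Ab = torch.sparse_coo_tensor(batched_idx, torch.tensor([1.0, 2.0]), (2, 2, 2), device="mps").coalesce()
Bb = torch.randn(2, 2, 3, device="mps")
y_bmm = torch.bmm(Ab, Bb)
```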

To compare MPS and CPU, you can use this script:
```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 200
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)

    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, K, device="mps")

    # warm up the MPS kernels and flush pending work before timing
    _ = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```
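
For a quick correctness cross-check of the MPS path against the CPU implementation (separate from the timing script above; shapes and tolerances are only illustrative):
```python
import torch

B, I, J, K, nnz = 4, 64, 64, 32, 500
indices = torch.stack([
    torch.randint(0, B, (nnz,)),
    torch.randint(0, I, (nnz,)),
    torch.randint(0, J, (nnz,)),
])
values = torch.rand(nnz)

sparse_mps = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
dense_mps = torch.randn(B, J, K, device="mps")

out_mps = torch.bmm(sparse_mps, dense_mps).cpu()
out_cpu = torch.bmm(sparse_mps.cpu(), dense_mps.cpu())

# small float32 accumulation differences between backends are expected
torch.testing.assert_close(out_mps, out_cpu, rtol=1e-4, atol=1e-4)
```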

## Tested on M1 Pro
<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet
Isalia20 authored on 2025-10-18 09:04:42 +00:00; committed by PyTorch MergeBot
Commit: ad67170c8b (parent fdab48a7c1)
5 changed files with 527 additions and 14 deletions

@@ -1370,6 +1370,7 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
SparseMPS: bmm_sparse_mps
NestedTensorCPU: bmm_nested
NestedTensorCUDA: bmm_nested_cuda
tags: core
@@ -1385,6 +1386,7 @@
MTIA: bmm_out_mtia
SparseCPU: bmm_out_sparse_cpu
SparseCUDA: bmm_out_sparse_cuda
SparseMPS: bmm_out_sparse_mps
SparseCsrCUDA: bmm_out_sparse_csr_cuda
- func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@@ -4173,7 +4175,7 @@
structured_delegate: mm.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA: _sparse_mm
SparseCPU, SparseCUDA, SparseMPS: _sparse_mm
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
tags: core
@@ -7112,6 +7114,7 @@
MTIA: addmm_out_mtia
SparseCPU: addmm_out_sparse_dense_cpu
SparseCUDA: addmm_out_sparse_dense_cuda
SparseMPS: addmm_out_sparse_dense_mps
SparseCsrCPU: addmm_out_sparse_compressed_cpu
SparseCsrCUDA: addmm_out_sparse_compressed_cuda
@@ -7121,6 +7124,7 @@
dispatch:
SparseCPU: addmm_sparse_dense_cpu
SparseCUDA: addmm_sparse_dense_cuda
SparseMPS: addmm_sparse_dense_mps
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
tags: core

@@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
@@ -18,6 +19,8 @@
#include <ATen/ops/ones_like.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/bmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
#endif
@@ -33,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/Mul_metallib.h>
#endif
static Tensor& s_addmm_out_sparse_dense_mps(
Tensor& r,
const Tensor& t,
const SparseTensor& sparse_,
const Tensor& dense,
const Scalar& beta,
const Scalar& alpha) {
TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim());
TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim());
TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim());
TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim());
const int64_t I = sparse_.size(0);
const int64_t J = sparse_.size(1);
const int64_t K = dense.size(1);
TORCH_CHECK(dense.size(0) == J,
"addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0));
TORCH_CHECK(t.size(0) == I && t.size(1) == K,
"addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")");
r.resize_({I, K});
auto sparse = sparse_.coalesce();
const int64_t nnz = sparse._nnz();
if (nnz == 0 || I == 0 || K == 0) {
at::mul_out(r, t, beta);
return r;
}
const auto v_dtype = sparse._values().scalar_type();
const auto d_dtype = dense.scalar_type();
const auto t_dtype = t.scalar_type();
auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype);
TORCH_CHECK(canCast(compute_dtype, r.scalar_type()),
"Can't convert computed type ", compute_dtype, " to output ", r.scalar_type());
auto indices2d = sparse._indices().contiguous();
auto values = sparse._values().to(compute_dtype);
auto dense_c = dense.to(compute_dtype).contiguous();
auto t_c = t.to(compute_dtype).contiguous();
const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous();
Tensor out_buf = out_needs_cast
? at::empty({I, K}, r.options().dtype(compute_dtype))
: r;
auto out_contig = out_buf.contiguous();
auto device = r.device();
auto stream = getCurrentMPSStream();
const float alpha_f = alpha.to<float>();
const float beta_f = beta.to<float>();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values);
auto pso = lib.getPipelineStateForFunc(func);
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t gridX = static_cast<uint32_t>(K);
const uint32_t gridZ = static_cast<uint32_t>(I);
const uint32_t tgW = std::min<uint32_t>(gridX, tew);
MTLSize grid = MTLSizeMake(gridX, 1, gridZ);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
indices2d,
values,
dense_c,
t_c,
out_contig,
std::array<uint32_t, 3>{static_cast<uint32_t>(I),
static_cast<uint32_t>(J),
static_cast<uint32_t>(K)},
std::array<float, 2>{alpha_f, beta_f},
static_cast<uint32_t>(nnz));
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
if (out_needs_cast) {
r.copy_(out_contig.to(r.scalar_type()));
}
return r;
}
static void build_batch_ptr_mps(
const Tensor& indices_dim0,
int64_t B,
Tensor& batch_ptr
) {
// Builds an array of offsets that point to each batch's elements. Example:
// idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2] // 9 non-zero elements
// └─────┘ └──┘ └─────────┘
// batch 0 batch 1 batch 2
// batch_ptr = [0, 3, 5, 9]
// │ │ │ └─ end of batch 2 (total nnz)
// │ │ └──── batch 2 starts at index 5
// │ └─────── batch 1 starts at index 3
// └────────── batch 0 starts at index 0
TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected");
auto device = indices_dim0.device();
auto stream = getCurrentMPSStream();
const int64_t nnz = indices_dim0.numel();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t Q = static_cast<uint32_t>(B + 1);
const uint32_t tgW = std::min<uint32_t>(Q, tew);
MTLSize grid = MTLSizeMake(Q, 1, 1);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
indices_dim0,
batch_ptr,
std::array<uint32_t, 2>{static_cast<uint32_t>(nnz),
static_cast<uint32_t>(B)});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
static void build_row_ptr_per_batch_mps(
const Tensor& rows,
const Tensor& batch_ptr,
int64_t B,
int64_t I,
Tensor& row_ptr
) {
// Build per-batch CSR-style row pointer arrays from row indices sorted by batch
// Given:
// rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch
// batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b
// Produces:
// - row_ptr: shape [B, I+1]
//
// Example (B = 2, I = 4):
// rows = [0, 0, 1, 3, 0, 2, 2] // 7 non-zero elements
// └─── batch 0 ──┘ └─ batch 1 ─┘
// batch_ptr = [0, 4, 7]
// │ │ └─ end of batch 1 (total nnz)
// │ └──── end of batch 0/start of batch 1
// └─────── start of batch 0
//
// per-batch row pointers (I+1 entries each):
// row_ptr[0] = [0, 2, 3, 3, 4]
// row_ptr[1] = [4, 5, 5, 7, 7]  (entries are absolute offsets into rows/cols/vals)
// laid out in memory: [0, 2, 3, 3, 4, 4, 5, 5, 7, 7]
TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected");
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t Qx = static_cast<uint32_t>(I + 1);
const uint32_t Qy = static_cast<uint32_t>(B);
const uint32_t tgW = std::min<uint32_t>(Qx, tew);
MTLSize grid = MTLSizeMake(Qx, Qy, 1);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
rows,
batch_ptr,
row_ptr,
std::array<uint32_t, 2>{static_cast<uint32_t>(I),
static_cast<uint32_t>(B)});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
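For reference, a CPU-side sketch of what the two builder kernels above compute (illustration only; `build_batch_ptr` and `build_row_ptr_per_batch` are hypothetical helpers, assuming indices are coalesced, i.e. sorted by batch and then row):
```python
import torch

def build_batch_ptr(idx_b: torch.Tensor, B: int) -> torch.Tensor:
    # batch_ptr[b] = first position whose batch id is >= b; batch_ptr[B] = nnz
    return torch.searchsorted(idx_b, torch.arange(B + 1, dtype=idx_b.dtype))

def build_row_ptr_per_batch(rows: torch.Tensor, batch_ptr: torch.Tensor, B: int, I: int) -> torch.Tensor:
    row_ptr = torch.empty(B * (I + 1), dtype=torch.long)
    for b in range(B):
        base, lim = batch_ptr[b].item(), batch_ptr[b + 1].item()
        # absolute offsets into rows/cols/vals, matching what the Metal kernel stores
        row_ptr[b * (I + 1):(b + 1) * (I + 1)] = base + torch.searchsorted(
            rows[base:lim], torch.arange(I + 1, dtype=rows.dtype))
    return row_ptr

# Worked example from the comments above:
rows = torch.tensor([0, 0, 1, 3, 0, 2, 2])
idx_b = torch.tensor([0, 0, 0, 0, 1, 1, 1])
batch_ptr = build_batch_ptr(idx_b, B=2)                       # tensor([0, 4, 7])
row_ptr = build_row_ptr_per_batch(rows, batch_ptr, B=2, I=4)  # tensor([0, 2, 3, 3, 4, 4, 5, 5, 7, 7])
```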
Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) {
TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device());
TORCH_CHECK(self_.is_mps(), "bmm_sparse: expected 'self' to be MPS, got ", self_.device());
TORCH_CHECK(mat2_.is_mps(), "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device());
TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim());
TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim());
TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim());
TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");
const int64_t B = self_.size(0);
const int64_t I = self_.size(1);
const int64_t J = self_.size(2);
const int64_t K = mat2_.size(2);
auto self = self_.coalesce();
const int64_t nnz = self._nnz();
if (nnz == 0) {
return result_.zero_();
}
const auto computeDtype = at::kFloat;
auto indices = self._indices();
auto values = self._values();
auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype);
auto mat2_c = mat2_.scalar_type() == computeDtype ? mat2_ : mat2_.to(computeDtype);
auto mat2_contig = mat2_c.contiguous();
auto idx_b = indices.select(0, 0).contiguous();
auto idx_i = indices.select(0, 1).contiguous();
auto idx_j = indices.select(0, 2).contiguous();
// build batch_ptr: batch_ptr[b] marks where batch b's entries start in the nnz arrays
// (see build_batch_ptr_mps above for a detailed explanation)
auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong));
build_batch_ptr_mps(idx_b, B, batch_ptr);
// build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals
auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong));
build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr);
const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous();
Tensor out_buf = out_needs_cast
? at::empty({B, I, K}, result_.options().dtype(computeDtype))
: result_;
auto out_contig = out_buf.contiguous();
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew);
// One threadgroup per (row i, batch b), lanes cover K
MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
idx_i,
idx_j,
values_c,
mat2_contig,
out_contig,
row_ptr,
std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
if (out_needs_cast) {
result_.copy_(out_contig.to(result_.scalar_type()));
}
return result_;
}
Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) {
Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options());
return bmm_out_sparse_mps(self, mat2, result);
}
Tensor& addmm_out_sparse_dense_mps(
const Tensor& self,
const SparseTensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}
Tensor addmm_sparse_dense_mps(
const Tensor& self,
const SparseTensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha
) {
c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
Tensor result = at::empty({0}, self.options());
return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
}
static SparseTensor& mul_out_dense_sparse_mps(
const Tensor& dense,
const Tensor& sparse,

@@ -1,10 +1,105 @@
#include <metal_stdlib>
#include <c10/metal/indexing.h>
#include <c10/metal/utils.h>
using namespace c10::metal;
using namespace metal;
inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) {
uint l = lo, r = hi;
while (l < r) {
uint m = (l + r) >> 1;
long v = arr[m];
if (v < key) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
template <typename T> struct MulAccum { using type = float; };
template <> struct MulAccum<float2> { using type = float2; };
inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) {
uint l = lo, r = hi;
while (l < r) {
uint m = (l + r) >> 1;
long v = arr[m];
if (v <= key) {
l = m + 1;
} else {
r = m;
}
}
return l;
}
kernel void build_row_ptr_from_sorted_rows_by_batch(
device const long* rows [[buffer(0)]],
device const long* batch_ptr [[buffer(1)]],
device long* row_ptr [[buffer(2)]],
constant uint2& dims [[buffer(3)]],
uint3 tid [[thread_position_in_grid]])
{
const uint I = dims.x;
const uint B = dims.y;
const uint i = tid.x;
const uint b = tid.y;
if (b >= B || i > I) return;
const uint base = (uint)batch_ptr[b];
const uint lim = (uint)batch_ptr[b + 1];
const ulong out_base = (ulong)b * (ulong)(I + 1);
if (i == I) {
row_ptr[out_base + (ulong)I] = (long)lim;
} else {
const long key = (long)i;
const uint pos = lower_bound_i64(rows, base, lim, key);
row_ptr[out_base + (ulong)i] = (long)pos;
}
}
template <typename T>
kernel void spmm_bmm_coo_rows_grouped(
device const long* rows [[buffer(0)]],
device const long* cols [[buffer(1)]],
device const T* vals [[buffer(2)]],
device const T* dense [[buffer(3)]],
device T* out [[buffer(4)]],
device const long* row_ptr [[buffer(5)]],
constant uint4& dims [[buffer(6)]],
uint3 tid [[thread_position_in_grid]],
uint3 ltid [[thread_position_in_threadgroup]],
uint3 tptg [[threads_per_threadgroup]])
{
const uint B = dims.x;
const uint I = dims.y;
const uint J = dims.z;
const uint K = dims.w;
const uint b = tid.z;
const uint i = tid.y;
const uint lane = ltid.x;
const uint tgW = tptg.x;
const ulong rp_base = (ulong)b * (ulong)(I + 1);
const uint start = (uint)row_ptr[rp_base + (ulong)i];
const uint end = (uint)row_ptr[rp_base + (ulong)i + 1];
for (uint k = lane; k < K; k += tgW) {
auto acc = static_cast<accum_t<T>>(T(0));
for (uint p = start; p < end; ++p) {
const uint c = (uint)cols[p];
const auto v = static_cast<accum_t<T>>(vals[p]);
const uint d_off = ((b * J) + c) * K + k;
const auto d = static_cast<accum_t<T>>(dense[d_off]);
acc += mul(v, d);
}
const uint y_off = ((b * I) + i) * K + k;
out[y_off] = static_cast<T>(acc);
}
}
template <typename T>
kernel void dense_sparse_mul_kernel(
@@ -32,10 +127,9 @@ kernel void dense_sparse_mul_kernel(
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
using accum_t = typename MulAccum<T>::type;
const accum_t a = static_cast<accum_t>(values[val_idx]);
const accum_t b = static_cast<accum_t>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(a * b);
const auto a = static_cast<accum_t<T>>(values[val_idx]);
const auto b = static_cast<accum_t<T>>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(mul(a, b));
}
kernel void intersect_binary_search(
@@ -120,6 +214,76 @@ kernel void fused_gather_mul_kernel(
}
}
kernel void build_batch_ptr_from_sorted_batches(
device const long* batches [[buffer(0)]],
device long* batch_ptr [[buffer(1)]],
constant uint2& nnz_B [[buffer(2)]],
uint3 tid [[thread_position_in_grid]])
{
uint b = tid.x;
uint nnz = nnz_B.x;
uint batch = nnz_B.y;
if (b == batch) {
batch_ptr[b] = (long)nnz;
return;
}
uint lo = 0;
uint hi = nnz;
long key = (long)b;
while (lo < hi) {
uint mid = (lo + hi) >> 1;
long v = batches[mid];
if (v < key) lo = mid + 1;
else hi = mid;
}
batch_ptr[b] = (long)lo;
}
template <typename T>
kernel void spmm_addmm_coo(
device const long* indices2d [[buffer(0)]],
device const T* vals [[buffer(1)]],
device const T* dense [[buffer(2)]],
device const T* t_in [[buffer(3)]],
device T* out [[buffer(4)]],
constant uint3& dims [[buffer(5)]],
constant float2& alpha_beta [[buffer(6)]],
constant uint& nnz [[buffer(7)]],
uint3 tid [[thread_position_in_grid]])
{
const uint K = dims.z;
const uint k = tid.x;
const uint i = tid.z;
const float alpha = alpha_beta.x;
const float beta = alpha_beta.y;
device const long* rows = indices2d;
device const long* cols = indices2d + nnz;
const uint start = lower_bound_i64(rows, 0u, nnz, (long)i);
const uint end = upper_bound_i64(rows, 0u, nnz, (long)i);
// accumulator is float for scalar/half/bfloat and float2 for float2
auto acc = static_cast<accum_t<T>>(T(0));
for (uint p = start; p < end; ++p) {
const uint c = (uint)cols[p];
const auto v = static_cast<accum_t<T>>(vals[p]);
const uint dense_off = c * K + k;
const auto d = static_cast<accum_t<T>>(dense[dense_off]);
acc += mul(v, d);
}
const uint off = i * K + k;
const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0));
const auto y = base + alpha * acc;
out[off] = static_cast<T>(y);
}
#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE) \
template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
dense_sparse_mul_kernel<DTYPE>( \
@@ -151,6 +315,36 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2);
constant uint2& dims_output [[buffer(8)]], \
uint3 gid [[thread_position_in_grid]]);
INSTANTIATE_FUSED_GATHER_MUL(float);
INSTANTIATE_FUSED_GATHER_MUL(half);
INSTANTIATE_FUSED_GATHER_MUL(bfloat);
INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
spmm_bmm_coo_rows_grouped<DTYPE>( \
device const long* rows [[buffer(0)]], \
device const long* cols [[buffer(1)]], \
device const DTYPE* vals [[buffer(2)]], \
device const DTYPE* dense [[buffer(3)]], \
device DTYPE* out [[buffer(4)]], \
device const long* row_ptr [[buffer(5)]], \
constant uint4& dims [[buffer(6)]], \
uint3 tid [[thread_position_in_grid]], \
uint3 ltid [[thread_position_in_threadgroup]], \
uint3 tptg [[threads_per_threadgroup]]);
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED);
#define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \
template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void \
spmm_addmm_coo<DTYPE>( \
device const long* indices2d [[buffer(0)]], \
device const DTYPE* vals [[buffer(1)]], \
device const DTYPE* dense [[buffer(2)]], \
device const DTYPE* t_in [[buffer(3)]], \
device DTYPE* out [[buffer(4)]], \
constant uint3& dims [[buffer(5)]], \
constant float2& alpha_beta [[buffer(6)]], \
constant uint& nnz [[buffer(7)]], \
uint3 tid [[thread_position_in_grid]]);
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO);

@@ -328,5 +328,21 @@ struct pair {
T2 second;
};
#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
MACRO(float); \
MACRO(half); \
MACRO(bfloat); \
MACRO(float2); \
MACRO(long); \
MACRO(char); \
MACRO(uchar); \
MACRO(short); \
MACRO(int);
#define INSTANTIATE_FOR_FLOAT_TYPES(MACRO) \
MACRO(float); \
MACRO(half); \
MACRO(bfloat);
} // namespace metal
} // namespace c10

@@ -1421,7 +1421,6 @@ class TestSparse(TestSparseBase):
"bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_bmm(self, device, dtype, coalesced):
@@ -1633,7 +1632,6 @@ class TestSparse(TestSparseBase):
self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))
@coalescedonoff
@expectedFailureMPS
@precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2})
@dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16)
@dtypesIfMPS(torch.float32, torch.complex64, torch.bfloat16, torch.float16)
@@ -1724,7 +1722,6 @@
# test_shape(2, 3, [2, 2, 0])
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_dsmm(self, device, dtype, coalesced):