Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[MPS] sparse matmuls (#165232)
Implements matmuls for sparse tensors. With this commit, most of the core sparse operations should be implemented.

Fixes:
- https://github.com/pytorch/pytorch/issues/156540
- https://github.com/pytorch/pytorch/issues/129842

Should be merged after: https://github.com/pytorch/pytorch/pull/165102

To compare MPS and CPU, you can use this script:

```python
import torch
import time
import matplotlib.pyplot as plt

B, I, J, K = 8, 20000, 20000, 20000
num_iterations = 500

nnz_values = [10, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 100000]
speedups = []

for nnz in nnz_values:
    indices = torch.stack([
        torch.randint(0, B, (nnz,)),
        torch.randint(0, I, (nnz,)),
        torch.randint(0, J, (nnz,)),
    ])
    values = torch.rand(nnz)
    sparse = torch.sparse_coo_tensor(indices, values, size=(B, I, J), device="mps").coalesce()
    dense = torch.randn(B, J, 200, device="mps")

    t1 = time.time()
    for _ in range(num_iterations):
        result = torch.bmm(sparse, dense)
    torch.mps.synchronize()
    t2 = time.time()
    mps_time = (t2 - t1) / num_iterations

    sparse_cpu = sparse.cpu()
    dense_cpu = dense.cpu()
    t1 = time.time()
    for _ in range(num_iterations):
        result_cpu = torch.bmm(sparse_cpu, dense_cpu)
    t2 = time.time()
    cpu_time = (t2 - t1) / num_iterations

    speedup = cpu_time / mps_time
    speedups.append(speedup)
    print(f"nnz={nnz}: MPS={mps_time:.6f}s, CPU={cpu_time:.6f}s, Speedup={speedup:.2f}x")

plt.figure(figsize=(10, 6))
plt.plot(nnz_values, speedups, marker='o', linewidth=2, markersize=8)
plt.xlabel('Number of Non-Zero Elements (nnz)', fontsize=12)
plt.ylabel('Speedup (CPU time / MPS time)', fontsize=12)
plt.title('MPS vs CPU Speedup for Sparse-Dense BMM', fontsize=14)
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', alpha=0.5)
plt.xscale('log')
plt.tight_layout()
plt.show()
```

## Tested on M1 Pro

<img width="1000" height="600" alt="Figure_1" src="https://github.com/user-attachments/assets/4a2402ec-3dc4-402d-8196-a0426906ca3d" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165232
Approved by: https://github.com/malfet
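A quick correctness check to go with the speed comparison above (not part of the PR; a minimal sketch that assumes a build containing this change and an MPS-capable machine):

```python
import torch

assert torch.backends.mps.is_available()

B, I, J, K, nnz = 4, 64, 32, 16, 100

# random batched COO sparse matrix on MPS and a dense right-hand side
idx3 = torch.stack([
    torch.randint(0, B, (nnz,)),
    torch.randint(0, I, (nnz,)),
    torch.randint(0, J, (nnz,)),
])
sp3 = torch.sparse_coo_tensor(idx3, torch.rand(nnz), size=(B, I, J), device="mps").coalesce()
dn3 = torch.randn(B, J, K, device="mps")

# (B, I, J) sparse x (B, J, K) dense -> (B, I, K) dense, on MPS and on CPU
out_mps = torch.bmm(sp3, dn3)
out_cpu = torch.bmm(sp3.cpu(), dn3.cpu())
print(torch.allclose(out_mps.cpu(), out_cpu, atol=1e-5))  # expected: True
```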
Committed by: PyTorch MergeBot
Parent: fdab48a7c1
Commit: ad67170c8b
ATen dispatch registrations (`native_functions.yaml`): the new `SparseMPS` kernels are wired into the existing `bmm`, `bmm.out`, `mm`, `addmm`, and `addmm.out` entries.

```diff
@@ -1370,6 +1370,7 @@
   dispatch:
     SparseCPU: bmm_sparse_cpu
     SparseCUDA: bmm_sparse_cuda
+    SparseMPS: bmm_sparse_mps
     NestedTensorCPU: bmm_nested
     NestedTensorCUDA: bmm_nested_cuda
   tags: core
@@ -1385,6 +1386,7 @@
     MTIA: bmm_out_mtia
     SparseCPU: bmm_out_sparse_cpu
     SparseCUDA: bmm_out_sparse_cuda
+    SparseMPS: bmm_out_sparse_mps
     SparseCsrCUDA: bmm_out_sparse_csr_cuda
 
 - func: bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor
@@ -4173,7 +4175,7 @@
   structured_delegate: mm.out
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA: _sparse_mm
+    SparseCPU, SparseCUDA, SparseMPS: _sparse_mm
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _sparse_csr_mm
   tags: core
 
@@ -7112,6 +7114,7 @@
     MTIA: addmm_out_mtia
     SparseCPU: addmm_out_sparse_dense_cpu
     SparseCUDA: addmm_out_sparse_dense_cuda
+    SparseMPS: addmm_out_sparse_dense_mps
     SparseCsrCPU: addmm_out_sparse_compressed_cpu
     SparseCsrCUDA: addmm_out_sparse_compressed_cuda
 
@@ -7121,6 +7124,7 @@
   dispatch:
     SparseCPU: addmm_sparse_dense_cpu
     SparseCUDA: addmm_sparse_dense_cuda
+    SparseMPS: addmm_sparse_dense_mps
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: addmm_sparse_compressed_dense
   tags: core
```
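These YAML entries only register the MPS kernels with the dispatcher; the user-facing API is unchanged. A minimal sketch (same assumptions as above) of calls that now route to the SparseMPS implementations:

```python
import torch

I, J, K, nnz = 128, 64, 32, 200
idx = torch.stack([torch.randint(0, I, (nnz,)), torch.randint(0, J, (nnz,))])
sp = torch.sparse_coo_tensor(idx, torch.rand(nnz), size=(I, J), device="mps").coalesce()
dense = torch.randn(J, K, device="mps")
t = torch.randn(I, K, device="mps")

y_mm = torch.sparse.mm(sp, dense)                         # mm -> _sparse_mm
y_addmm = torch.addmm(t, sp, dense, beta=0.5, alpha=2.0)  # addmm -> addmm_sparse_dense_mps
```

`torch.bmm(sparse3d, dense3d)` and its `out=` variant similarly hit the new `bmm_sparse_mps` / `bmm_out_sparse_mps` entries.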
Host-side implementation (Objective-C++): `s_addmm_out_sparse_dense_mps` launches the `spmm_addmm_coo` kernel, while `bmm_out_sparse_mps` first builds per-batch and per-row pointer arrays on the GPU and then launches `spmm_bmm_coo_rows_grouped`.

```diff
@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/SparseTensorUtils.h>
+#include <ATen/ExpandUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/native/sparse/SparseStubs.h>
 #include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>
@@ -18,6 +19,8 @@
 #include <ATen/ops/ones_like.h>
 #include <ATen/ops/argsort.h>
 #include <ATen/ops/result_type.h>
+#include <ATen/ops/bmm_native.h>
+#include <ATen/ops/addmm_native.h>
 #include <ATen/ops/copy_sparse_to_sparse.h>
 #include <ATen/ops/mul.h>
 #endif
@@ -33,6 +36,305 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
 #include <ATen/native/mps/Mul_metallib.h>
 #endif
 
+static Tensor& s_addmm_out_sparse_dense_mps(
+    Tensor& r,
+    const Tensor& t,
+    const SparseTensor& sparse_,
+    const Tensor& dense,
+    const Scalar& beta,
+    const Scalar& alpha) {
+  TORCH_CHECK(sparse_.sparse_dim() == 2, "addmm: sparse_dim must be 2, got ", sparse_.sparse_dim());
+  TORCH_CHECK(sparse_.dense_dim() == 0, "addmm: sparse values must be 0-dense-dim, got ", sparse_.dense_dim());
+  TORCH_CHECK(dense.dim() == 2, "addmm: 'dense' must be 2D, got ", dense.dim());
+  TORCH_CHECK(t.dim() == 2, "addmm: 't' must be 2D, got ", t.dim());
+
+  const int64_t I = sparse_.size(0);
+  const int64_t J = sparse_.size(1);
+  const int64_t K = dense.size(1);
+
+  TORCH_CHECK(dense.size(0) == J,
+      "addmm: dense (mat2) dim0 must be ", J, ", got ", dense.size(0));
+  TORCH_CHECK(t.size(0) == I && t.size(1) == K,
+      "addmm: 't' shape must be (", I, ", ", K, "), got (", t.size(0), ", ", t.size(1), ")");
+
+  r.resize_({I, K});
+
+  auto sparse = sparse_.coalesce();
+  const int64_t nnz = sparse._nnz();
+
+  if (nnz == 0 || I == 0 || K == 0) {
+    at::mul_out(r, t, beta);
+    return r;
+  }
+
+  const auto v_dtype = sparse._values().scalar_type();
+  const auto d_dtype = dense.scalar_type();
+  const auto t_dtype = t.scalar_type();
+  auto compute_dtype = c10::promoteTypes(c10::promoteTypes(v_dtype, d_dtype), t_dtype);
+
+  TORCH_CHECK(canCast(compute_dtype, r.scalar_type()),
+      "Can't convert computed type ", compute_dtype, " to output ", r.scalar_type());
+
+  auto indices2d = sparse._indices().contiguous();
+  auto values = sparse._values().to(compute_dtype);
+  auto dense_c = dense.to(compute_dtype).contiguous();
+  auto t_c = t.to(compute_dtype).contiguous();
+
+  const bool out_needs_cast = (r.scalar_type() != compute_dtype) || !r.is_contiguous();
+  Tensor out_buf = out_needs_cast
+      ? at::empty({I, K}, r.options().dtype(compute_dtype))
+      : r;
+  auto out_contig = out_buf.contiguous();
+
+  auto device = r.device();
+  auto stream = getCurrentMPSStream();
+
+  const float alpha_f = alpha.to<float>();
+  const float beta_f = beta.to<float>();
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      const std::string func = "spmm_addmm_coo_" + mps::scalarToMetalTypeString(values);
+      auto pso = lib.getPipelineStateForFunc(func);
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+
+      const uint32_t tew = pso.threadExecutionWidth;
+      const uint32_t gridX = static_cast<uint32_t>(K);
+      const uint32_t gridZ = static_cast<uint32_t>(I);
+      const uint32_t tgW = std::min<uint32_t>(gridX, tew);
+
+      MTLSize grid = MTLSizeMake(gridX, 1, gridZ);
+      MTLSize tgs = MTLSizeMake(tgW, 1, 1);
+
+      mtl_setArgs(enc,
+                  indices2d,
+                  values,
+                  dense_c,
+                  t_c,
+                  out_contig,
+                  std::array<uint32_t, 3>{static_cast<uint32_t>(I),
+                                          static_cast<uint32_t>(J),
+                                          static_cast<uint32_t>(K)},
+                  std::array<float, 2>{alpha_f, beta_f},
+                  static_cast<uint32_t>(nnz));
+      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
+    }
+  });
+
+  if (out_needs_cast) {
+    r.copy_(out_contig.to(r.scalar_type()));
+  }
+
+  return r;
+}
+
+static void build_batch_ptr_mps(
+    const Tensor& indices_dim0,
+    int64_t B,
+    Tensor& batch_ptr
+) {
+  // Builds an array of pointers which point to each batches elements. Example:
+  // idx_b = [0, 0, 0, 1, 1, 2, 2, 2, 2]  // 9 non-zero elements
+  //          └─────┘  └──┘  └────────┘
+  //          batch 0 batch 1  batch 2
+  // batch_ptr = [0, 3, 5, 9]
+  //              │  │  │  └─ end of batch 2 (total nnz)
+  //              │  │  └──── batch 2 starts at index 5
+  //              │  └─────── batch 1 starts at index 3
+  //              └────────── batch 0 starts at index 0
+  TORCH_CHECK(indices_dim0.is_mps() && batch_ptr.is_mps(), "MPS device expected");
+  auto device = indices_dim0.device();
+  auto stream = getCurrentMPSStream();
+
+  const int64_t nnz = indices_dim0.numel();
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("build_batch_ptr_from_sorted_batches");
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+
+      const uint32_t tew = pso.threadExecutionWidth;
+      const uint32_t Q = static_cast<uint32_t>(B + 1);
+      const uint32_t tgW = std::min<uint32_t>(Q, tew);
+      MTLSize grid = MTLSizeMake(Q, 1, 1);
+      MTLSize tgs = MTLSizeMake(tgW, 1, 1);
+
+      mtl_setArgs(enc,
+                  indices_dim0,
+                  batch_ptr,
+                  std::array<uint32_t, 2>{static_cast<uint32_t>(nnz),
+                                          static_cast<uint32_t>(B)});
+      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
+    }
+  });
+}
+
+static void build_row_ptr_per_batch_mps(
+    const Tensor& rows,
+    const Tensor& batch_ptr,
+    int64_t B,
+    int64_t I,
+    Tensor& row_ptr
+) {
+  // Build per-batch CSR-style row pointer arrays from row indices sorted by batch
+  // Given:
+  //   rows: 1-D array of length nnz with row ids in [0, I), sorted within each batch
+  //   batch_ptr: length B+1, where [batch_ptr[b], batch_ptr[b+1]) is the subrange for batch b
+  // Produces:
+  //   - row_ptr: shape [B, I+1]
+  //
+  // Example (B = 2, I = 4):
+  //   rows      = [0, 0, 1, 3,  0, 2, 2]   // 7 non-zero elements
+  //                └─ batch 0 ┘ └batch 1┘
+  //   batch_ptr = [0, 4, 7]
+  //                │  │  └─ end of batch 1 (total nnz)
+  //                │  └──── end of batch 0/start of batch 1
+  //                └─────── start of batch 0
+  //
+  //   per-batch row pointers (I+1 entries each):
+  //   row_ptr[0] = [0, 2, 3, 3, 4]
+  //   row_ptr[1] = [0, 1, 1, 3, 3]
+  //   laid out in memory: [0, 2, 3, 3, 4, 0, 1, 1, 3, 3]
+  TORCH_CHECK(rows.is_mps() && batch_ptr.is_mps() && row_ptr.is_mps(), "MPS device expected");
+  auto stream = getCurrentMPSStream();
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("build_row_ptr_from_sorted_rows_by_batch");
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+
+      const uint32_t tew = pso.threadExecutionWidth;
+      const uint32_t Qx = static_cast<uint32_t>(I + 1);
+      const uint32_t Qy = static_cast<uint32_t>(B);
+      const uint32_t tgW = std::min<uint32_t>(Qx, tew);
+
+      MTLSize grid = MTLSizeMake(Qx, Qy, 1);
+      MTLSize tgs = MTLSizeMake(tgW, 1, 1);
+
+      mtl_setArgs(enc,
+                  rows,
+                  batch_ptr,
+                  row_ptr,
+                  std::array<uint32_t, 2>{static_cast<uint32_t>(I),
+                                          static_cast<uint32_t>(B)});
+      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
+    }
+  });
+}
+
+Tensor& bmm_out_sparse_mps(const SparseTensor& self_, const Tensor& mat2_, Tensor& result_) {
+  TORCH_CHECK(result_.is_mps(), "bmm_sparse: expected 'out' to be MPS, got ", result_.device());
+  TORCH_CHECK(self_.is_mps(), "bmm_sparse: expected 'self' to be MPS, got ", self_.device());
+  TORCH_CHECK(mat2_.is_mps(), "bmm_sparse: expected 'mat2' to be MPS, got ", mat2_.device());
+
+  TORCH_CHECK(self_.dense_dim() == 0, "bmm_sparse: Tensor 'self' must have 0 dense dims, but has ", self_.dense_dim());
+  TORCH_CHECK(self_.sparse_dim() == 3, "bmm_sparse: Tensor 'self' must have 3 sparse dims, but has ", self_.sparse_dim());
+  TORCH_CHECK(mat2_.dim() == 3, "bmm_sparse: Tensor 'mat2' must have 3 dims, but has ", mat2_.dim());
+
+  TORCH_CHECK(self_.size(0) == mat2_.size(0), "bmm_sparse: 'self.size(0)' and 'mat2.size(0)' must match");
+  TORCH_CHECK(self_.size(2) == mat2_.size(1), "bmm_sparse: 'self.size(2)' and 'mat2.size(1)' must match");
+
+  const int64_t B = self_.size(0);
+  const int64_t I = self_.size(1);
+  const int64_t J = self_.size(2);
+  const int64_t K = mat2_.size(2);
+
+  auto self = self_.coalesce();
+  const int64_t nnz = self._nnz();
+  if (nnz == 0) {
+    return result_.zero_();
+  }
+
+  const auto computeDtype = at::kFloat;
+
+  auto indices = self._indices();
+  auto values = self._values();
+
+  auto values_c = values.scalar_type() == computeDtype ? values : values.to(computeDtype);
+  auto mat2_c = mat2_.scalar_type() == computeDtype ? mat2_ : mat2_.to(computeDtype);
+  auto mat2_contig = mat2_c.contiguous();
+
+  auto idx_b = indices.select(0, 0).contiguous();
+  auto idx_i = indices.select(0, 1).contiguous();
+  auto idx_j = indices.select(0, 2).contiguous();
+
+  // builds an array of pointers of where the batch_idx's pointer starts and ends
+  // look in function for better explanation
+  auto batch_ptr = at::empty({B + 1}, at::device(result_.device()).dtype(kLong));
+  build_batch_ptr_mps(idx_b, B, batch_ptr);
+  // build row_ptr per batch: for each (b, i) get [start, end) into rows/cols/vals
+  auto row_ptr = at::empty({B * (I + 1)}, at::device(result_.device()).dtype(kLong));
+  build_row_ptr_per_batch_mps(idx_i, batch_ptr, B, I, row_ptr);
+
+  const bool out_needs_cast = (result_.scalar_type() != computeDtype) || !result_.is_contiguous();
+  Tensor out_buf = out_needs_cast
+      ? at::empty({B, I, K}, result_.options().dtype(computeDtype))
+      : result_;
+  auto out_contig = out_buf.contiguous();
+
+  auto stream = getCurrentMPSStream();
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("spmm_bmm_coo_rows_grouped_" + mps::scalarToMetalTypeString(values));
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+
+      const uint32_t tew = pso.threadExecutionWidth;
+      const uint32_t tgW = std::min<uint32_t>((uint32_t)K, tew);
+
+      // One threadgroup per (row i, batch b), lanes cover K
+      MTLSize grid = MTLSizeMake(tgW, (uint32_t)I, (uint32_t)B);
+      MTLSize tgs = MTLSizeMake(tgW, 1, 1);
+
+      mtl_setArgs(enc,
+                  idx_i,
+                  idx_j,
+                  values_c,
+                  mat2_contig,
+                  out_contig,
+                  row_ptr,
+                  std::array<uint32_t, 4>{(uint32_t)B, (uint32_t)I, (uint32_t)J, (uint32_t)K});
+      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
+    }
+  });
+  if (out_needs_cast) {
+    result_.copy_(out_contig.to(result_.scalar_type()));
+  }
+  return result_;
+}
+
+Tensor bmm_sparse_mps(const Tensor& self, const Tensor& mat2) {
+  Tensor result = at::zeros({self.size(0), self.size(1), mat2.size(2)}, mat2.options());
+  return bmm_out_sparse_mps(self, mat2, result);
+}
+
+Tensor& addmm_out_sparse_dense_mps(
+    const Tensor& self,
+    const SparseTensor& mat1,
+    const Tensor& mat2,
+    const Scalar& beta,
+    const Scalar& alpha,
+    Tensor& result) {
+  c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
+  return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
+}
+
+Tensor addmm_sparse_dense_mps(
+    const Tensor& self,
+    const SparseTensor& mat1,
+    const Tensor& mat2,
+    const Scalar& beta,
+    const Scalar& alpha
+) {
+  c10::MaybeOwned<Tensor> b_self = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out");
+  Tensor result = at::empty({0}, self.options());
+  return s_addmm_out_sparse_dense_mps(result, *b_self, mat1, mat2, beta, alpha);
+}
+
 static SparseTensor& mul_out_dense_sparse_mps(
     const Tensor& dense,
     const Tensor& sparse,
```
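For intuition, a CPU sketch (not from the PR) of what `build_batch_ptr_mps` and `build_row_ptr_per_batch_mps` compute from the coalesced COO indices; both kernels are plain binary searches, expressed here with `torch.searchsorted`. The sketch stores absolute positions into the nnz arrays, which is what `spmm_bmm_coo_rows_grouped` consumes:

```python
import torch

def build_batch_ptr_ref(idx_b: torch.Tensor, B: int) -> torch.Tensor:
    # batch_ptr[b] = first position whose batch id is >= b; batch_ptr[B] = nnz
    return torch.searchsorted(idx_b, torch.arange(B + 1))

def build_row_ptr_per_batch_ref(idx_i: torch.Tensor, batch_ptr: torch.Tensor,
                                B: int, I: int) -> torch.Tensor:
    row_ptr = torch.empty(B, I + 1, dtype=torch.long)
    for b in range(B):
        lo, hi = batch_ptr[b].item(), batch_ptr[b + 1].item()
        seg = idx_i[lo:hi]                                   # row ids of batch b, sorted
        row_ptr[b, :I] = lo + torch.searchsorted(seg, torch.arange(I))
        row_ptr[b, I] = hi                                   # end of batch b
    return row_ptr.flatten()

# same example as in the comments above: rows = [0, 0, 1, 3, 0, 2, 2]
idx_b = torch.tensor([0, 0, 0, 0, 1, 1, 1])
idx_i = torch.tensor([0, 0, 1, 3, 0, 2, 2])
bp = build_batch_ptr_ref(idx_b, B=2)                  # tensor([0, 4, 7])
print(build_row_ptr_per_batch_ref(idx_i, bp, B=2, I=4))
# tensor([0, 2, 3, 3, 4, 4, 5, 5, 7, 7])  (absolute offsets into rows/cols/vals)
```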
@ -1,10 +1,105 @@
|
|||||||
#include <metal_stdlib>
|
|
||||||
#include <c10/metal/indexing.h>
|
#include <c10/metal/indexing.h>
|
||||||
|
#include <c10/metal/utils.h>
|
||||||
|
using namespace c10::metal;
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
inline uint lower_bound_i64(device const long* arr, uint lo, uint hi, long key) {
|
||||||
|
uint l = lo, r = hi;
|
||||||
|
while (l < r) {
|
||||||
|
uint m = (l + r) >> 1;
|
||||||
|
long v = arr[m];
|
||||||
|
if (v < key) {
|
||||||
|
l = m + 1;
|
||||||
|
} else {
|
||||||
|
r = m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T> struct MulAccum { using type = float; };
|
inline uint upper_bound_i64(device const long* arr, uint lo, uint hi, long key) {
|
||||||
template <> struct MulAccum<float2> { using type = float2; };
|
uint l = lo, r = hi;
|
||||||
|
while (l < r) {
|
||||||
|
uint m = (l + r) >> 1;
|
||||||
|
long v = arr[m];
|
||||||
|
if (v <= key) {
|
||||||
|
l = m + 1;
|
||||||
|
} else {
|
||||||
|
r = m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void build_row_ptr_from_sorted_rows_by_batch(
|
||||||
|
device const long* rows [[buffer(0)]],
|
||||||
|
device const long* batch_ptr [[buffer(1)]],
|
||||||
|
device long* row_ptr [[buffer(2)]],
|
||||||
|
constant uint2& dims [[buffer(3)]],
|
||||||
|
uint3 tid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
const uint I = dims.x;
|
||||||
|
const uint B = dims.y;
|
||||||
|
|
||||||
|
const uint i = tid.x;
|
||||||
|
const uint b = tid.y;
|
||||||
|
|
||||||
|
if (b >= B || i > I) return;
|
||||||
|
|
||||||
|
const uint base = (uint)batch_ptr[b];
|
||||||
|
const uint lim = (uint)batch_ptr[b + 1];
|
||||||
|
|
||||||
|
const ulong out_base = (ulong)b * (ulong)(I + 1);
|
||||||
|
|
||||||
|
if (i == I) {
|
||||||
|
row_ptr[out_base + (ulong)I] = (long)lim;
|
||||||
|
} else {
|
||||||
|
const long key = (long)i;
|
||||||
|
const uint pos = lower_bound_i64(rows, base, lim, key);
|
||||||
|
row_ptr[out_base + (ulong)i] = (long)pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
kernel void spmm_bmm_coo_rows_grouped(
|
||||||
|
device const long* rows [[buffer(0)]],
|
||||||
|
device const long* cols [[buffer(1)]],
|
||||||
|
device const T* vals [[buffer(2)]],
|
||||||
|
device const T* dense [[buffer(3)]],
|
||||||
|
device T* out [[buffer(4)]],
|
||||||
|
device const long* row_ptr [[buffer(5)]],
|
||||||
|
constant uint4& dims [[buffer(6)]],
|
||||||
|
uint3 tid [[thread_position_in_grid]],
|
||||||
|
uint3 ltid [[thread_position_in_threadgroup]],
|
||||||
|
uint3 tptg [[threads_per_threadgroup]])
|
||||||
|
{
|
||||||
|
const uint B = dims.x;
|
||||||
|
const uint I = dims.y;
|
||||||
|
const uint J = dims.z;
|
||||||
|
const uint K = dims.w;
|
||||||
|
|
||||||
|
const uint b = tid.z;
|
||||||
|
const uint i = tid.y;
|
||||||
|
const uint lane = ltid.x;
|
||||||
|
const uint tgW = tptg.x;
|
||||||
|
|
||||||
|
const ulong rp_base = (ulong)b * (ulong)(I + 1);
|
||||||
|
const uint start = (uint)row_ptr[rp_base + (ulong)i];
|
||||||
|
const uint end = (uint)row_ptr[rp_base + (ulong)i + 1];
|
||||||
|
|
||||||
|
for (uint k = lane; k < K; k += tgW) {
|
||||||
|
auto acc = static_cast<accum_t<T>>(T(0));
|
||||||
|
for (uint p = start; p < end; ++p) {
|
||||||
|
const uint c = (uint)cols[p];
|
||||||
|
const auto v = static_cast<accum_t<T>>(vals[p]);
|
||||||
|
const uint d_off = ((b * J) + c) * K + k;
|
||||||
|
const auto d = static_cast<accum_t<T>>(dense[d_off]);
|
||||||
|
acc += mul(v, d);
|
||||||
|
}
|
||||||
|
const uint y_off = ((b * I) + i) * K + k;
|
||||||
|
out[y_off] = static_cast<T>(acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
kernel void dense_sparse_mul_kernel(
|
kernel void dense_sparse_mul_kernel(
|
||||||
@ -32,10 +127,9 @@ kernel void dense_sparse_mul_kernel(
|
|||||||
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
|
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
|
||||||
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
|
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
|
||||||
|
|
||||||
using accum_t = typename MulAccum<T>::type;
|
const auto a = static_cast<accum_t<T>>(values[val_idx]);
|
||||||
const accum_t a = static_cast<accum_t>(values[val_idx]);
|
const auto b = static_cast<accum_t<T>>(dense[dense_idx]);
|
||||||
const accum_t b = static_cast<accum_t>(dense[dense_idx]);
|
out_values[val_idx] = static_cast<T>(mul(a, b));
|
||||||
out_values[val_idx] = static_cast<T>(a * b);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void intersect_binary_search(
|
kernel void intersect_binary_search(
|
||||||
@ -120,6 +214,76 @@ kernel void fused_gather_mul_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
kernel void build_batch_ptr_from_sorted_batches(
|
||||||
|
device const long* batches [[buffer(0)]],
|
||||||
|
device long* batch_ptr [[buffer(1)]],
|
||||||
|
constant uint2& nnz_B [[buffer(2)]],
|
||||||
|
uint3 tid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint b = tid.x;
|
||||||
|
uint nnz = nnz_B.x;
|
||||||
|
uint batch = nnz_B.y;
|
||||||
|
|
||||||
|
if (b == batch) {
|
||||||
|
batch_ptr[b] = (long)nnz;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint lo = 0;
|
||||||
|
uint hi = nnz;
|
||||||
|
long key = (long)b;
|
||||||
|
while (lo < hi) {
|
||||||
|
uint mid = (lo + hi) >> 1;
|
||||||
|
long v = batches[mid];
|
||||||
|
if (v < key) lo = mid + 1;
|
||||||
|
else hi = mid;
|
||||||
|
}
|
||||||
|
batch_ptr[b] = (long)lo;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
kernel void spmm_addmm_coo(
|
||||||
|
device const long* indices2d [[buffer(0)]],
|
||||||
|
device const T* vals [[buffer(1)]],
|
||||||
|
device const T* dense [[buffer(2)]],
|
||||||
|
device const T* t_in [[buffer(3)]],
|
||||||
|
device T* out [[buffer(4)]],
|
||||||
|
constant uint3& dims [[buffer(5)]],
|
||||||
|
constant float2& alpha_beta [[buffer(6)]],
|
||||||
|
constant uint& nnz [[buffer(7)]],
|
||||||
|
uint3 tid [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
const uint K = dims.z;
|
||||||
|
const uint k = tid.x;
|
||||||
|
const uint i = tid.z;
|
||||||
|
const float alpha = alpha_beta.x;
|
||||||
|
const float beta = alpha_beta.y;
|
||||||
|
|
||||||
|
device const long* rows = indices2d;
|
||||||
|
device const long* cols = indices2d + nnz;
|
||||||
|
|
||||||
|
const uint start = lower_bound_i64(rows, 0u, nnz, (long)i);
|
||||||
|
const uint end = upper_bound_i64(rows, 0u, nnz, (long)i);
|
||||||
|
|
||||||
|
// accumulator is float for scalar/half/bfloat and float2 for float2
|
||||||
|
auto acc = static_cast<accum_t<T>>(T(0));
|
||||||
|
|
||||||
|
for (uint p = start; p < end; ++p) {
|
||||||
|
const uint c = (uint)cols[p];
|
||||||
|
const auto v = static_cast<accum_t<T>>(vals[p]);
|
||||||
|
const uint dense_off = c * K + k;
|
||||||
|
const auto d = static_cast<accum_t<T>>(dense[dense_off]);
|
||||||
|
acc += mul(v, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint off = i * K + k;
|
||||||
|
const auto base = (beta != 0.0f) ? (static_cast<accum_t<T>>(t_in[off]) * beta) : static_cast<accum_t<T>>(T(0));
|
||||||
|
const auto y = base + alpha * acc;
|
||||||
|
out[off] = static_cast<T>(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE) \
|
#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE) \
|
||||||
template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
|
template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
|
||||||
dense_sparse_mul_kernel<DTYPE>( \
|
dense_sparse_mul_kernel<DTYPE>( \
|
||||||
@ -151,6 +315,36 @@ INSTANTIATE_DENSE_SPARSE_MUL(float2);
|
|||||||
constant uint2& dims_output [[buffer(8)]], \
|
constant uint2& dims_output [[buffer(8)]], \
|
||||||
uint3 gid [[thread_position_in_grid]]);
|
uint3 gid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
INSTANTIATE_FUSED_GATHER_MUL(float);
|
INSTANTIATE_FOR_FLOAT_TYPES(INSTANTIATE_FUSED_GATHER_MUL);
|
||||||
INSTANTIATE_FUSED_GATHER_MUL(half);
|
|
||||||
INSTANTIATE_FUSED_GATHER_MUL(bfloat);
|
|
||||||
|
#define INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED(DTYPE) \
|
||||||
|
template [[host_name("spmm_bmm_coo_rows_grouped_" #DTYPE)]] kernel void \
|
||||||
|
spmm_bmm_coo_rows_grouped<DTYPE>( \
|
||||||
|
device const long* rows [[buffer(0)]], \
|
||||||
|
device const long* cols [[buffer(1)]], \
|
||||||
|
device const DTYPE* vals [[buffer(2)]], \
|
||||||
|
device const DTYPE* dense [[buffer(3)]], \
|
||||||
|
device DTYPE* out [[buffer(4)]], \
|
||||||
|
device const long* row_ptr [[buffer(5)]], \
|
||||||
|
constant uint4& dims [[buffer(6)]], \
|
||||||
|
uint3 tid [[thread_position_in_grid]], \
|
||||||
|
uint3 ltid [[thread_position_in_threadgroup]], \
|
||||||
|
uint3 tptg [[threads_per_threadgroup]]);
|
||||||
|
|
||||||
|
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_BMM_COO_ROWS_GROUPED);
|
||||||
|
|
||||||
|
#define INSTANTIATE_SPMM_ADDMM_COO(DTYPE) \
|
||||||
|
template [[host_name("spmm_addmm_coo_" #DTYPE)]] kernel void \
|
||||||
|
spmm_addmm_coo<DTYPE>( \
|
||||||
|
device const long* indices2d [[buffer(0)]], \
|
||||||
|
device const DTYPE* vals [[buffer(1)]], \
|
||||||
|
device const DTYPE* dense [[buffer(2)]], \
|
||||||
|
device const DTYPE* t_in [[buffer(3)]], \
|
||||||
|
device DTYPE* out [[buffer(4)]], \
|
||||||
|
constant uint3& dims [[buffer(5)]], \
|
||||||
|
constant float2& alpha_beta [[buffer(6)]], \
|
||||||
|
constant uint& nnz [[buffer(7)]], \
|
||||||
|
uint3 tid [[thread_position_in_grid]]);
|
||||||
|
|
||||||
|
INSTANTIATE_FOR_ALL_TYPES(INSTANTIATE_SPMM_ADDMM_COO);
|
||||||
|
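Not part of the diff: a CPU reference for what one `spmm_addmm_coo` thread computes per output element `(i, k)`. The nonzeros of row `i` are the contiguous segment of the sorted COO row indices found by the lower/upper bounds, and the `beta != 0` guard mirrors the kernel's epilogue:

```python
import torch

def spmm_addmm_coo_ref(rows, cols, vals, dense, t, alpha, beta):
    # rows/cols/vals: coalesced COO entries of an (I, J) sparse matrix, rows sorted
    # dense: (J, K), t: (I, K); returns beta * t + alpha * (sparse @ dense)
    I, K = t.shape
    out = torch.empty_like(t)
    for i in range(I):
        start = int((rows < i).sum())    # lower_bound over the sorted row ids
        end = int((rows <= i).sum())     # upper_bound
        for k in range(K):
            acc = 0.0
            for p in range(start, end):
                acc += vals[p].item() * dense[cols[p], k].item()
            base = beta * t[i, k].item() if beta != 0.0 else 0.0
            out[i, k] = base + alpha * acc
    return out

# tiny check against the dense formula
I, J, K, nnz = 5, 4, 3, 6
idx = torch.stack([torch.randint(0, I, (nnz,)), torch.randint(0, J, (nnz,))])
sp = torch.sparse_coo_tensor(idx, torch.rand(nnz), size=(I, J)).coalesce()
rows, cols = sp.indices()        # sorted lexicographically by (row, col)
vals = sp.values()
dense, t = torch.randn(J, K), torch.randn(I, K)
ref = 0.5 * t + 2.0 * (sp.to_dense() @ dense)
out = spmm_addmm_coo_ref(rows, cols, vals, dense, t, alpha=2.0, beta=0.5)
print(torch.allclose(out, ref, atol=1e-5))  # expected: True
```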
Shared Metal header: type-instantiation helper macros used by the new kernels.

```diff
@@ -328,5 +328,21 @@ struct pair {
   T2 second;
 };
 
+#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
+  MACRO(float);                          \
+  MACRO(half);                           \
+  MACRO(bfloat);                         \
+  MACRO(float2);                         \
+  MACRO(long);                           \
+  MACRO(char);                           \
+  MACRO(uchar);                          \
+  MACRO(short);                          \
+  MACRO(int);
+
+#define INSTANTIATE_FOR_FLOAT_TYPES(MACRO) \
+  MACRO(float);                            \
+  MACRO(half);                             \
+  MACRO(bfloat);
+
 } // namespace metal
 } // namespace c10
```
Test changes: the `@expectedFailureMPS` decorators are dropped from the sparse matmul tests that now pass on MPS.

```diff
@@ -1421,7 +1421,6 @@ class TestSparse(TestSparseBase):
         "bmm sparse-dense CUDA is not yet supported in Windows, at least up to CUDA 10.1"
     )
     @coalescedonoff
-    @expectedFailureMPS
    @dtypes(torch.double)
    @dtypesIfMPS(torch.float32)
    def test_bmm(self, device, dtype, coalesced):
@@ -1633,7 +1632,6 @@ class TestSparse(TestSparseBase):
         self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))
 
     @coalescedonoff
-    @expectedFailureMPS
     @precisionOverride({torch.bfloat16: 5e-2, torch.float16: 5e-2})
     @dtypes(torch.double, torch.cdouble, torch.bfloat16, torch.float16)
     @dtypesIfMPS(torch.float32, torch.complex64, torch.bfloat16, torch.float16)
@@ -1724,7 +1722,6 @@ class TestSparse(TestSparseBase):
         # test_shape(2, 3, [2, 2, 0])
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
     @dtypesIfMPS(torch.float32)
     def test_dsmm(self, device, dtype, coalesced):
```