[MPS] mps sparse mul op implementation (#162349)

Implements mps sparse mul operation as well as enables other operations such as:
1. copy_
2. div
3. sum
4. floor
5. power
6. sub
7. floor_divide

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162349
Approved by: https://github.com/pearu, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Isalia20
2025-09-09 15:45:34 +00:00
committed by PyTorch MergeBot
parent be3b8d2ec9
commit 3ea6868049
4 changed files with 483 additions and 25 deletions

View File

@ -1798,7 +1798,7 @@
device_guard: False
dispatch:
MkldnnCPU: copy_mkldnn_
SparseCPU, SparseCUDA: copy_sparse_wrapper_
SparseCPU, SparseCUDA, SparseMPS: copy_sparse_wrapper_
CompositeExplicitAutograd: copy_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: copy_sparse_compressed_
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: copy_nested_
@ -2160,7 +2160,7 @@
variants: function, method
structured_delegate: div.out
dispatch:
SparseCPU, SparseCUDA: div_sparse
SparseCPU, SparseCUDA, SparseMPS: div_sparse
ZeroTensor: div_zerotensor
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_div_Tensor
tags: [core, pointwise]
@ -2170,7 +2170,7 @@
variants: method
structured_delegate: div.out
dispatch:
SparseCPU, SparseCUDA: div_sparse_
SparseCPU, SparseCUDA, SparseMPS: div_sparse_
tags: pointwise
- func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
@ -2179,7 +2179,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS, MTIA: div_out
SparseCPU, SparseCUDA: div_out_sparse_zerodim
SparseCPU, SparseCUDA, SparseMPS: div_out_sparse_zerodim
tags: pointwise
- func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
@ -2187,7 +2187,7 @@
variants: function, method
structured_delegate: div.out_mode
dispatch:
SparseCPU, SparseCUDA: div_sparse
SparseCPU, SparseCUDA, SparseMPS: div_sparse
tags: [core, pointwise]
- func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
@ -2195,7 +2195,7 @@
variants: method
structured_delegate: div.out_mode
dispatch:
SparseCPU, SparseCUDA: div_sparse_
SparseCPU, SparseCUDA, SparseMPS: div_sparse_
tags: pointwise
- func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
@ -2204,7 +2204,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA, MPS: div_out_mode
SparseCPU, SparseCUDA: div_out_sparse_zerodim
SparseCPU, SparseCUDA, SparseMPS: div_out_sparse_zerodim
tags: pointwise
# For C++ only, until we have conversion from C++ numbers to Tensor
@ -2768,20 +2768,20 @@
variants: function, method
dispatch:
CPU, CUDA, MPS, MTIA: floor_divide
SparseCPU, SparseCUDA: floor_divide_sparse
SparseCPU, SparseCUDA, SparseMPS: floor_divide_sparse
- func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
device_check: NoCheck # TensorIterator
variants: method
dispatch:
CPU, CUDA, MPS: floor_divide_
SparseCPU, SparseCUDA: floor_divide_sparse_
SparseCPU, SparseCUDA, SparseMPS: floor_divide_sparse_
- func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA, MPS: floor_divide_out
SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim
SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim
- func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
device_check: NoCheck # TensorIterator
@ -4273,7 +4273,7 @@
structured_delegate: mul.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA: mul_sparse
SparseCPU, SparseCUDA, SparseMPS: mul_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr
MkldnnCPU: mkldnn_mul
ZeroTensor: mul_zerotensor
@ -4285,7 +4285,7 @@
structured_delegate: mul.out
variants: method
dispatch:
SparseCPU, SparseCUDA: mul_sparse_
SparseCPU, SparseCUDA, SparseMPS: mul_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr_
MkldnnCPU: mkldnn_mul_
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul__Tensor
@ -4299,6 +4299,7 @@
CPU, CUDA, MPS, MTIA: mul_out
SparseCPU: mul_out_sparse_cpu
SparseCUDA: mul_out_sparse_cuda
SparseMPS: mul_out_sparse_mps
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr
MkldnnCPU: mkldnn_mul_out
tags: pointwise
@ -5848,7 +5849,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: sum
SparseCPU, SparseCUDA, SparseMeta: sum_coo
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
autogen: sum.out
@ -5859,7 +5860,7 @@
variants: function, method
dispatch:
NestedTensorCPU: NestedTensor_sum_dim_CPU
SparseCPU, SparseCUDA: sum_sparse_coo
SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
tags: core
@ -6975,7 +6976,7 @@
CPU, CUDA: sub_out
MPS: sub_out_mps
MTIA: sub_out_mtia
SparseCPU, SparseCUDA: sub_out_sparse
SparseCPU, SparseCUDA, SparseMPS: sub_out_sparse
tags: pointwise
- func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@ -6983,7 +6984,7 @@
variants: function, method
structured_delegate: sub.out
dispatch:
SparseCPU, SparseCUDA: sub_sparse
SparseCPU, SparseCUDA, SparseMPS: sub_sparse
ZeroTensor: sub_zerotensor
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sub_Tensor
tags: [core, pointwise]
@ -6993,7 +6994,7 @@
variants: method
structured_delegate: sub.out
dispatch:
SparseCPU, SparseCUDA: sub_sparse_
SparseCPU, SparseCUDA, SparseMPS: sub_sparse_
tags: pointwise
# For C++ only, until we have conversion from C++ numbers to Tensor
@ -10342,7 +10343,7 @@
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: pow_Tensor_Scalar_out
SparseCPU, SparseCUDA: pow_out_sparse_scalar
SparseCPU, SparseCUDA, SparseMPS: pow_out_sparse_scalar
MPS: pow_tensor_scalar_out_mps
tags: pointwise
@ -10351,7 +10352,7 @@
structured_delegate: pow.Tensor_Scalar_out
variants: function, method
dispatch:
SparseCPU, SparseCUDA: pow_sparse_scalar
SparseCPU, SparseCUDA, SparseMPS: pow_sparse_scalar
tags: [core, pointwise]
- func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)

View File

@ -10,6 +10,7 @@
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/mul_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#include <ATen/ops/result_type.h>
@ -20,10 +21,268 @@
namespace at::native {
using namespace at::sparse;
using namespace mps;
Tensor& add_out_dense_sparse_mps(Tensor& out, const Tensor& dense, const SparseTensor& sparse, const Scalar& alpha);
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Mul_metallib.h>
#endif
Tensor& add_out_dense_sparse_mps(
static SparseTensor& mul_out_dense_sparse_mps(
const Tensor& dense,
const Tensor& sparse,
SparseTensor& out) {
TORCH_CHECK(sparse.is_sparse(), "mul: expected 'sparse' to be sparse COO");
TORCH_CHECK(sparse.is_mps(), "mul: expected 'sparse' to be MPS, got ", sparse.device());
TORCH_CHECK(out.is_mps(), "mul: expected 'out' to be MPS, got ", out.device());
const bool scalar_like = (dense.dim() == 0) || (dense.numel() == 1);
TORCH_CHECK(dense.is_mps() || scalar_like,
"mul: expected 'dense' to be MPS or scalar-like, got ", dense.device());
const int64_t nnz = sparse._nnz();
out.resize_as_(sparse);
auto commonDtype = at::result_type(dense, sparse);
TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
"Can't convert result type ", commonDtype, " to output ", out.scalar_type());
auto indices = sparse._indices().contiguous();
auto values = sparse._values().to(commonDtype).contiguous();
if (nnz == 0) {
auto empty_vals = values.narrow(0, 0, 0);
alias_into_sparse(out,
indices.narrow(1, 0, 0),
(out.scalar_type() == commonDtype) ? empty_vals
: empty_vals.to(out.scalar_type()));
out._coalesced_(sparse.is_coalesced());
return out;
}
if (scalar_like) {
auto scalar = dense;
if (dense.numel() == 1 && dense.dim() > 0) {
scalar = dense.view({});
}
scalar = scalar.to(values.options());
auto out_vals = values.mul(scalar);
if (out.scalar_type() != commonDtype) {
out_vals = out_vals.to(out.scalar_type());
}
alias_into_sparse(out, indices, out_vals);
out._coalesced_(sparse.is_coalesced());
return out;
}
TORCH_CHECK(dense.sizes().equals(sparse.sizes()),
"mul(dense, sparse): sizes must match exactly (no broadcasting): ",
dense.sizes(), " vs ", sparse.sizes());
const int64_t nDimI = sparse.sparse_dim();
const int64_t nDim = dense.dim();
TORCH_CHECK(
nDimI <= nDim,
"mul(dense, sparse): sparse_dim=", nDimI, " exceeds dense.dim()=", nDim);
// Prepare shapes
int64_t view_rows = 1, view_cols = 1;
for (int64_t i = 0; i < nDimI; ++i) view_rows *= sparse.size(i);
for (int64_t i = nDimI; i < nDim; ++i) view_cols *= sparse.size(i);
auto dense_mps = dense.to(commonDtype).contiguous().reshape({view_rows, view_cols});
auto out_vals = at::empty_like(values, values.options());
const uint32_t u_view_cols = static_cast<uint32_t>(view_cols);
const uint32_t u_nnz = static_cast<uint32_t>(nnz);
const uint32_t u_nDimI = static_cast<uint32_t>(nDimI);
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("dense_sparse_mul_kernel_" + mps::scalarToMetalTypeString(values));
auto computeEncoder = stream->commandEncoder();
[computeEncoder setComputePipelineState:pso];
const uint32_t gridWidth = u_view_cols;
const uint32_t gridDepth = u_nnz;
MTLSize gridSize = MTLSizeMake(gridWidth, 1, gridDepth);
const uint32_t maxThreadsPerGroup = pso.maxTotalThreadsPerThreadgroup;
const uint32_t tew = pso.threadExecutionWidth;
uint32_t tgWidth = std::min(gridWidth, tew);
MTLSize threadgroupSize = MTLSizeMake(tgWidth, 1, 1);
mtl_setArgs(
computeEncoder,
dense_mps,
values,
out_vals,
indices,
sparse.sizes(),
u_nnz,
u_nDimI,
u_view_cols
);
[computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
}
});
Tensor final_vals = out_vals;
if (out.scalar_type() != commonDtype) {
final_vals = final_vals.to(out.scalar_type());
}
alias_into_sparse(out, indices, final_vals);
out._coalesced_(sparse.is_coalesced());
return out;
}
SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
// Dense x sparse fallback (keep dense first)
if (!t_.is_sparse() || !src_.is_sparse()) {
const Tensor& dense = t_.is_sparse() ? src_ : t_;
const Tensor& sparse = t_.is_sparse() ? t_ : src_;
return mul_out_dense_sparse_mps(dense, sparse, r_);
}
TORCH_CHECK(t_.is_mps(), "mul: expected 'self' to be MPS, but got ", t_.device());
TORCH_CHECK(src_.is_mps(), "mul: expected 'other' to be MPS, but got ", src_.device());
TORCH_CHECK(t_.sparse_dim() == src_.sparse_dim(),
"mul(sparse, sparse): must have same sparse_dim, got ",
t_.sparse_dim(), " vs ", src_.sparse_dim());
TORCH_CHECK(t_.sizes().equals(src_.sizes()),
"mul(sparse, sparse): sizes must match exactly (no broadcasting).");
// Coalesce and early-exit on structurally empty operands
auto lhs = t_.coalesce();
auto rhs = src_.coalesce();
const int64_t lhs_nnz = lhs._nnz();
const int64_t rhs_nnz = rhs._nnz();
if (!lhs_nnz || !rhs_nnz) {
r_.resize_as_(lhs);
return r_.zero_();
}
// dtype checks and promotion
auto commonDtype = at::result_type(lhs, rhs);
TORCH_CHECK(canCast(commonDtype, r_.scalar_type()),
"Can't convert result type ", commonDtype, " to output ", r_.scalar_type());
const int64_t nDimI = lhs.sparse_dim();
// nDimI == 0, at most one structural entry
if (nDimI == 0) {
r_.resize_as_(lhs);
const bool has = (lhs_nnz && rhs_nnz);
auto out_indices = lhs._indices().narrow(1, 0, has ? 1 : 0);
Tensor lhs_vals = lhs._values().to(commonDtype);
Tensor rhs_vals = rhs._values().to(commonDtype);
lhs_vals = lhs_vals.narrow(0, 0, has ? 1 : 0);
rhs_vals = rhs_vals.narrow(0, 0, has ? 1 : 0);
Tensor out_values = lhs_vals.mul(rhs_vals);
if (r_.scalar_type() != commonDtype) {
out_values = out_values.to(r_.scalar_type());
}
alias_into_sparse(r_, out_indices, out_values);
r_._coalesced_(true);
return r_;
}
// General path, intersect keys, then gather + multiply on GPU
const auto device = r_.device();
auto stream = getCurrentMPSStream();
auto lhs_indices = lhs._indices();
auto rhs_indices = rhs._indices();
auto lhs_values = lhs._values().to(commonDtype);
auto rhs_values = rhs._values().to(commonDtype);
// Flatten sparse indices to keys
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
// Intersect sorted keys (search the shorter in the longer)
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
const int64_t lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
const int64_t lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
auto counter = at::zeros({1}, at::device(device).dtype(kInt));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
static_cast<uint32_t>(lenB), A_is_lhs);
mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
}
});
const uint32_t M = counter.item<int32_t>(); // number of structural matches
r_.resize_as_(lhs);
auto out_indices = at::empty({nDimI, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
auto lhs_match = outA_idx.narrow(0, 0, M);
auto rhs_match = outB_idx.narrow(0, 0, M);
auto out_val_sizes = lhs_values.sizes().vec();
out_val_sizes[0] = static_cast<int64_t>(M);
auto out_values = at::empty(out_val_sizes, lhs_values.options());
const uint32_t cols = static_cast<uint32_t>(
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc(
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
uint32_t tgW = std::min(cols, tew);
MTLSize grid = MTLSizeMake(cols, 1, M);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
lhs_values, rhs_values,
lhs_match, rhs_match,
lhs_indices, out_indices,
out_values, static_cast<uint32_t>(nDimI),
static_cast<uint32_t>(lhs_nnz),
static_cast<uint32_t>(M),
cols);
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
if (r_.scalar_type() != commonDtype) {
out_values = out_values.to(r_.scalar_type());
}
alias_into_sparse(r_, out_indices, out_values);
r_._coalesced_(true);
return r_;
}
static Tensor& add_out_dense_sparse_mps(
Tensor& out,
const Tensor& dense,
const SparseTensor& sparse,

View File

@ -0,0 +1,198 @@
#include <metal_stdlib>
#include <c10/metal/indexing.h>
using namespace metal;
template <typename T>
kernel void dense_sparse_mul_kernel(
device const T* dense [[buffer(0)]],
device const T* values [[buffer(1)]],
device T* out_values [[buffer(2)]],
device const long* indices [[buffer(3)]],
device const long* sizes [[buffer(4)]],
constant uint& nnz [[buffer(5)]],
constant uint& ndim_i [[buffer(6)]],
constant uint& view_cols [[buffer(7)]],
uint3 gid [[thread_position_in_grid]])
{
uint col = gid.x;
uint i = gid.z;
long key = 0;
for (uint d = 0; d < ndim_i; ++d) {
long idx_d = indices[(ulong)d * (ulong)nnz + (ulong)i];
const auto sz_d = sizes[d];
key = key * sz_d + idx_d;
}
ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
const auto a = static_cast<float>(values[val_idx]);
const auto b = static_cast<float>(dense[dense_idx]);
out_values[val_idx] = static_cast<T>(a * b);
}
kernel void intersect_binary_search(
device const long* keysA [[buffer(0)]],
device const long* keysB [[buffer(1)]],
device long* outA_idx [[buffer(2)]],
device long* outB_idx [[buffer(3)]],
device atomic_uint* counter [[buffer(4)]],
constant uint& lenB [[buffer(5)]],
constant bool& A_is_lhs [[buffer(6)]],
uint3 tid_in_grid [[thread_position_in_grid]])
{
uint gid = tid_in_grid.x;
long key = keysA[gid];
// lower_bound in B
uint lo = 0;
uint hi = lenB;
while (lo < hi) {
uint mid = (lo + hi) >> 1;
long v = keysB[mid];
if (v < key) lo = mid + 1;
else hi = mid;
}
if (lo < lenB && keysB[lo] == key) {
uint pos = atomic_fetch_add_explicit(counter, 1u, memory_order_relaxed);
if (A_is_lhs) {
outA_idx[pos] = (long)gid;
outB_idx[pos] = (long)lo;
} else {
outA_idx[pos] = (long)lo;
outB_idx[pos] = (long)gid;
}
}
}
template <typename T>
kernel void fused_gather_mul_kernel(
device const T* lhs_vals [[buffer(0)]],
device const T* rhs_vals [[buffer(1)]],
device const long* lhs_sel [[buffer(2)]],
device const long* rhs_sel [[buffer(3)]],
device const long* lhs_indices [[buffer(4)]],
device long* out_indices [[buffer(5)]],
device T* out_vals [[buffer(6)]],
constant uint& nDimI [[buffer(7)]],
constant uint& L [[buffer(8)]],
constant uint& M [[buffer(9)]],
constant uint& view_cols [[buffer(10)]],
uint3 gid [[thread_position_in_grid]])
{
const uint col = gid.x;
const uint k = gid.z;
const long iL = lhs_sel[k];
const long iR = rhs_sel[k];
if (col < view_cols) {
const ulong offL = (ulong)iL * (ulong)view_cols + (ulong)col;
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
const float a = (float)lhs_vals[offL];
const float b = (float)rhs_vals[offR];
out_vals[offO] = (T)(a * b);
}
// One thread per match copies the indices column
if (col == 0) {
const ulong uL = (ulong)L;
const ulong uM = (ulong)M;
const ulong src_col = (ulong)iL; // gather from lhs
for (uint d = 0; d < nDimI; ++d) {
const long v = lhs_indices[(ulong)d * uL + src_col];
out_indices[(ulong)d * uM + (ulong)k] = v;
}
}
}
kernel void gather_int64_columns(
device const long* in [[buffer(0)]],
device const long* sel [[buffer(1)]],
device long* out [[buffer(2)]],
constant uint& nDimI [[buffer(3)]],
constant uint& L [[buffer(4)]],
constant uint& M [[buffer(5)]],
uint g [[thread_position_in_grid]])
{
uint d32 = g / M;
uint k32 = g - d32 * M;
ulong d = (ulong)d32;
ulong k = (ulong)k32;
long src_col = sel[k];
long v = in[d * (ulong)L + (ulong)src_col];
out[d * (ulong)M + k] = v;
}
template <typename T>
kernel void pairwise_gather_mul_kernel(
device const T* lhs_vals [[buffer(0)]],
device const T* rhs_vals [[buffer(1)]],
device const long* lhs_sel [[buffer(2)]],
device const long* rhs_sel [[buffer(3)]],
device T* out_vals [[buffer(4)]],
constant uint& M [[buffer(5)]],
constant uint& view_cols [[buffer(6)]],
uint3 gid [[thread_position_in_grid]])
{
uint col = gid.x;
uint k = gid.z;
long iL = lhs_sel[k];
long iR = rhs_sel[k];
ulong offL = (ulong)iL * (ulong)view_cols + (ulong)col;
ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
float a = (float)lhs_vals[offL];
float b = (float)rhs_vals[offR];
out_vals[offO] = (T)(a * b);
}
#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE) \
template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
dense_sparse_mul_kernel<DTYPE>( \
device const DTYPE* dense [[buffer(0)]], \
device const DTYPE* values [[buffer(1)]], \
device DTYPE* out_values [[buffer(2)]], \
device const long* indices [[buffer(3)]], \
device const long* sizes [[buffer(4)]], \
constant uint& nnz [[buffer(5)]], \
constant uint& nDimI [[buffer(6)]], \
constant uint& view_cols [[buffer(7)]], \
uint3 gid [[thread_position_in_grid]]);
INSTANTIATE_DENSE_SPARSE_MUL(float);
INSTANTIATE_DENSE_SPARSE_MUL(half);
INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
#define INSTANTIATE_FUSED_GATHER_MUL(DTYPE) \
template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void \
fused_gather_mul_kernel<DTYPE>( \
device const DTYPE* lhs_vals [[buffer(0)]], \
device const DTYPE* rhs_vals [[buffer(1)]], \
device const long* lhs_sel [[buffer(2)]], \
device const long* rhs_sel [[buffer(3)]], \
device const long* lhs_indices [[buffer(4)]], \
device long* out_indices [[buffer(5)]], \
device DTYPE* out_vals [[buffer(6)]], \
constant uint& nDimI [[buffer(7)]], \
constant uint& L [[buffer(8)]], \
constant uint& M [[buffer(9)]], \
constant uint& view_cols [[buffer(10)]], \
uint3 gid [[thread_position_in_grid]]);
INSTANTIATE_FUSED_GATHER_MUL(float);
INSTANTIATE_FUSED_GATHER_MUL(half);
INSTANTIATE_FUSED_GATHER_MUL(bfloat);

View File

@ -1108,8 +1108,8 @@ class TestSparse(TestSparseBase):
test_shape(2, 20, [3, 17, 19, 5])
test_shape(2, 20, [3, 17, 19, 0])
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_add_sub_nnz(self, device, dtype):
# nnz should not grow unbounded (gh-34964)
x = torch.randn(10, dtype=dtype, device=device).to_sparse()
@ -1687,8 +1687,8 @@ class TestSparse(TestSparseBase):
test_shape(7, 8, 9, 20, True)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
@gradcheck_semantics()
def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
@ -1868,8 +1868,8 @@ class TestSparse(TestSparseBase):
x.norm(**kwargs)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
@unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
def test_sparse_sum(self, device, dtype, coalesced):
@ -1933,7 +1933,6 @@ class TestSparse(TestSparseBase):
S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
run_tests(S.requires_grad_(True), test_dim)
@expectedFailureMPS
def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
shape = shape_i + (shape_v)
x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
@ -2011,6 +2010,7 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_basic_ops(self, device, dtype, coalesced):
def _test_basic_ops():