Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[MPS] mps sparse mul op implementation (#162349)
Implements the MPS sparse mul operation and also enables the following sparse ops on MPS:

1. copy_
2. div
3. sum
4. floor
5. pow
6. sub
7. floor_divide

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162349
Approved by: https://github.com/pearu, https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: be3b8d2ec9
Commit: 3ea6868049
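
A minimal usage sketch of what this change enables, assuming an MPS-capable build that contains the commit (shapes and values are illustrative):

import torch

# Two sparse COO tensors on the MPS device (float32, matching the MPS test dtypes).
i = torch.tensor([[0, 1, 2], [2, 0, 1]], device="mps")
v = torch.tensor([1.0, 2.0, 3.0], device="mps")
a = torch.sparse_coo_tensor(i, v, (3, 3)).coalesce()
b = torch.sparse_coo_tensor(i, v * 10, (3, 3)).coalesce()

c = a * b                                  # sparse * sparse -> mul_out_sparse_mps
d = torch.ones(3, 3, device="mps") * a     # dense * sparse path of the same kernel
e = a * 2.0                                # scalar-like path

print(c.to_dense(), d.to_dense(), e.to_dense(), sep="\n")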
@@ -1798,7 +1798,7 @@
   device_guard: False
   dispatch:
     MkldnnCPU: copy_mkldnn_
-    SparseCPU, SparseCUDA: copy_sparse_wrapper_
+    SparseCPU, SparseCUDA, SparseMPS: copy_sparse_wrapper_
     CompositeExplicitAutograd: copy_
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: copy_sparse_compressed_
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: copy_nested_
@@ -2160,7 +2160,7 @@
   variants: function, method
   structured_delegate: div.out
   dispatch:
-    SparseCPU, SparseCUDA: div_sparse
+    SparseCPU, SparseCUDA, SparseMPS: div_sparse
     ZeroTensor: div_zerotensor
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_div_Tensor
   tags: [core, pointwise]
@@ -2170,7 +2170,7 @@
   variants: method
   structured_delegate: div.out
   dispatch:
-    SparseCPU, SparseCUDA: div_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: div_sparse_
   tags: pointwise

 - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
@@ -2179,7 +2179,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA, MPS, MTIA: div_out
-    SparseCPU, SparseCUDA: div_out_sparse_zerodim
+    SparseCPU, SparseCUDA, SparseMPS: div_out_sparse_zerodim
   tags: pointwise

 - func: div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor
@@ -2187,7 +2187,7 @@
   variants: function, method
   structured_delegate: div.out_mode
   dispatch:
-    SparseCPU, SparseCUDA: div_sparse
+    SparseCPU, SparseCUDA, SparseMPS: div_sparse
   tags: [core, pointwise]

 - func: div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)
@@ -2195,7 +2195,7 @@
   variants: method
   structured_delegate: div.out_mode
   dispatch:
-    SparseCPU, SparseCUDA: div_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: div_sparse_
   tags: pointwise

 - func: div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)
@@ -2204,7 +2204,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA, MPS: div_out_mode
-    SparseCPU, SparseCUDA: div_out_sparse_zerodim
+    SparseCPU, SparseCUDA, SparseMPS: div_out_sparse_zerodim
   tags: pointwise

 # For C++ only, until we have conversion from C++ numbers to Tensor
@@ -2768,20 +2768,20 @@
   variants: function, method
   dispatch:
     CPU, CUDA, MPS, MTIA: floor_divide
-    SparseCPU, SparseCUDA: floor_divide_sparse
+    SparseCPU, SparseCUDA, SparseMPS: floor_divide_sparse

 - func: floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   variants: method
   dispatch:
     CPU, CUDA, MPS: floor_divide_
-    SparseCPU, SparseCUDA: floor_divide_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: floor_divide_sparse_

 - func: floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   dispatch:
     CPU, CUDA, MPS: floor_divide_out
-    SparseCPU, SparseCUDA: floor_divide_out_sparse_zerodim
+    SparseCPU, SparseCUDA, SparseMPS: floor_divide_out_sparse_zerodim

 - func: floor_divide.Scalar(Tensor self, Scalar other) -> Tensor
   device_check: NoCheck # TensorIterator
@@ -4273,7 +4273,7 @@
   structured_delegate: mul.out
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA: mul_sparse
+    SparseCPU, SparseCUDA, SparseMPS: mul_sparse
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr
     MkldnnCPU: mkldnn_mul
     ZeroTensor: mul_zerotensor
@@ -4285,7 +4285,7 @@
   structured_delegate: mul.out
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA: mul_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: mul_sparse_
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_sparse_csr_
     MkldnnCPU: mkldnn_mul_
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_mul__Tensor
@@ -4299,6 +4299,7 @@
     CPU, CUDA, MPS, MTIA: mul_out
     SparseCPU: mul_out_sparse_cpu
     SparseCUDA: mul_out_sparse_cuda
+    SparseMPS: mul_out_sparse_mps
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: mul_out_sparse_csr
     MkldnnCPU: mkldnn_mul_out
   tags: pointwise
@@ -5848,7 +5849,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: sum
-    SparseCPU, SparseCUDA, SparseMeta: sum_coo
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: sum_coo
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
   autogen: sum.out

@@ -5859,7 +5860,7 @@
   variants: function, method
   dispatch:
     NestedTensorCPU: NestedTensor_sum_dim_CPU
-    SparseCPU, SparseCUDA: sum_sparse_coo
+    SparseCPU, SparseCUDA, SparseMPS: sum_sparse_coo
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_sparse_compressed
   tags: core

@@ -6975,7 +6976,7 @@
     CPU, CUDA: sub_out
     MPS: sub_out_mps
     MTIA: sub_out_mtia
-    SparseCPU, SparseCUDA: sub_out_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sub_out_sparse
   tags: pointwise

 - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
@@ -6983,7 +6984,7 @@
   variants: function, method
   structured_delegate: sub.out
   dispatch:
-    SparseCPU, SparseCUDA: sub_sparse
+    SparseCPU, SparseCUDA, SparseMPS: sub_sparse
     ZeroTensor: sub_zerotensor
     NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_sub_Tensor
   tags: [core, pointwise]
@@ -6993,7 +6994,7 @@
   variants: method
   structured_delegate: sub.out
   dispatch:
-    SparseCPU, SparseCUDA: sub_sparse_
+    SparseCPU, SparseCUDA, SparseMPS: sub_sparse_
   tags: pointwise
 # For C++ only, until we have conversion from C++ numbers to Tensor

@@ -10342,7 +10343,7 @@
   structured_inherits: TensorIteratorBase
   dispatch:
     CPU, CUDA: pow_Tensor_Scalar_out
-    SparseCPU, SparseCUDA: pow_out_sparse_scalar
+    SparseCPU, SparseCUDA, SparseMPS: pow_out_sparse_scalar
     MPS: pow_tensor_scalar_out_mps
   tags: pointwise

@@ -10351,7 +10352,7 @@
   structured_delegate: pow.Tensor_Scalar_out
   variants: function, method
   dispatch:
-    SparseCPU, SparseCUDA: pow_sparse_scalar
+    SparseCPU, SparseCUDA, SparseMPS: pow_sparse_scalar
   tags: [core, pointwise]

 - func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)
@@ -10,6 +10,7 @@
 #include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
 #include <ATen/ops/cat.h>
 #include <ATen/ops/add_native.h>
+#include <ATen/ops/mul_native.h>
 #include <ATen/ops/empty_native.h>
 #include <ATen/ops/zeros_native.h>
 #include <ATen/ops/result_type.h>
@@ -20,10 +21,268 @@
 namespace at::native {

 using namespace at::sparse;
 using namespace mps;

+Tensor& add_out_dense_sparse_mps(Tensor& out, const Tensor& dense, const SparseTensor& sparse, const Scalar& alpha);
+
+#ifndef PYTORCH_JIT_COMPILE_SHADERS
+static auto& lib = MetalShaderLibrary::getBundledLibrary();
+#else
+#include <ATen/native/mps/Mul_metallib.h>
+#endif
+
-Tensor& add_out_dense_sparse_mps(
+static SparseTensor& mul_out_dense_sparse_mps(
+    const Tensor& dense,
+    const Tensor& sparse,
+    SparseTensor& out) {
+  TORCH_CHECK(sparse.is_sparse(), "mul: expected 'sparse' to be sparse COO");
+  TORCH_CHECK(sparse.is_mps(), "mul: expected 'sparse' to be MPS, got ", sparse.device());
+  TORCH_CHECK(out.is_mps(), "mul: expected 'out' to be MPS, got ", out.device());
+
+  const bool scalar_like = (dense.dim() == 0) || (dense.numel() == 1);
+  TORCH_CHECK(dense.is_mps() || scalar_like,
+              "mul: expected 'dense' to be MPS or scalar-like, got ", dense.device());
+
+  const int64_t nnz = sparse._nnz();
+  out.resize_as_(sparse);
+
+  auto commonDtype = at::result_type(dense, sparse);
+  TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
+              "Can't convert result type ", commonDtype, " to output ", out.scalar_type());
+
+  auto indices = sparse._indices().contiguous();
+  auto values = sparse._values().to(commonDtype).contiguous();
+
+  if (nnz == 0) {
+    auto empty_vals = values.narrow(0, 0, 0);
+    alias_into_sparse(out,
+                      indices.narrow(1, 0, 0),
+                      (out.scalar_type() == commonDtype) ? empty_vals
+                                                         : empty_vals.to(out.scalar_type()));
+    out._coalesced_(sparse.is_coalesced());
+    return out;
+  }
+
+  if (scalar_like) {
+    auto scalar = dense;
+    if (dense.numel() == 1 && dense.dim() > 0) {
+      scalar = dense.view({});
+    }
+    scalar = scalar.to(values.options());
+    auto out_vals = values.mul(scalar);
+    if (out.scalar_type() != commonDtype) {
+      out_vals = out_vals.to(out.scalar_type());
+    }
+
+    alias_into_sparse(out, indices, out_vals);
+    out._coalesced_(sparse.is_coalesced());
+    return out;
+  }
+
+  TORCH_CHECK(dense.sizes().equals(sparse.sizes()),
+              "mul(dense, sparse): sizes must match exactly (no broadcasting): ",
+              dense.sizes(), " vs ", sparse.sizes());
+
+  const int64_t nDimI = sparse.sparse_dim();
+  const int64_t nDim = dense.dim();
+  TORCH_CHECK(
+      nDimI <= nDim,
+      "mul(dense, sparse): sparse_dim=", nDimI, " exceeds dense.dim()=", nDim);
+
+  // Prepare shapes
+  int64_t view_rows = 1, view_cols = 1;
+  for (int64_t i = 0; i < nDimI; ++i) view_rows *= sparse.size(i);
+  for (int64_t i = nDimI; i < nDim; ++i) view_cols *= sparse.size(i);
+
+  auto dense_mps = dense.to(commonDtype).contiguous().reshape({view_rows, view_cols});
+  auto out_vals = at::empty_like(values, values.options());
+
+  const uint32_t u_view_cols = static_cast<uint32_t>(view_cols);
+  const uint32_t u_nnz = static_cast<uint32_t>(nnz);
+  const uint32_t u_nDimI = static_cast<uint32_t>(nDimI);
+
+  auto stream = getCurrentMPSStream();
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("dense_sparse_mul_kernel_" + mps::scalarToMetalTypeString(values));
+      auto computeEncoder = stream->commandEncoder();
+      [computeEncoder setComputePipelineState:pso];
+
+      const uint32_t gridWidth = u_view_cols;
+      const uint32_t gridDepth = u_nnz;
+      MTLSize gridSize = MTLSizeMake(gridWidth, 1, gridDepth);
+
+      const uint32_t maxThreadsPerGroup = pso.maxTotalThreadsPerThreadgroup;
+      const uint32_t tew = pso.threadExecutionWidth;
+      uint32_t tgWidth = std::min(gridWidth, tew);
+      MTLSize threadgroupSize = MTLSizeMake(tgWidth, 1, 1);
+
+      mtl_setArgs(
+          computeEncoder,
+          dense_mps,
+          values,
+          out_vals,
+          indices,
+          sparse.sizes(),
+          u_nnz,
+          u_nDimI,
+          u_view_cols);
+
+      [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
+    }
+  });
+
+  Tensor final_vals = out_vals;
+  if (out.scalar_type() != commonDtype) {
+    final_vals = final_vals.to(out.scalar_type());
+  }
+
+  alias_into_sparse(out, indices, final_vals);
+  out._coalesced_(sparse.is_coalesced());
+  return out;
+}
+
+SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTensor& r_) {
+  TORCH_CHECK(r_.is_mps(), "mul: expected 'out' to be MPS, but got ", r_.device());
+
+  // Dense x sparse fallback (keep dense first)
+  if (!t_.is_sparse() || !src_.is_sparse()) {
+    const Tensor& dense  = t_.is_sparse() ? src_ : t_;
+    const Tensor& sparse = t_.is_sparse() ? t_ : src_;
+    return mul_out_dense_sparse_mps(dense, sparse, r_);
+  }
+
+  TORCH_CHECK(t_.is_mps(), "mul: expected 'self' to be MPS, but got ", t_.device());
+  TORCH_CHECK(src_.is_mps(), "mul: expected 'other' to be MPS, but got ", src_.device());
+  TORCH_CHECK(t_.sparse_dim() == src_.sparse_dim(),
+              "mul(sparse, sparse): must have same sparse_dim, got ",
+              t_.sparse_dim(), " vs ", src_.sparse_dim());
+  TORCH_CHECK(t_.sizes().equals(src_.sizes()),
+              "mul(sparse, sparse): sizes must match exactly (no broadcasting).");
+
+  // Coalesce and early-exit on structurally empty operands
+  auto lhs = t_.coalesce();
+  auto rhs = src_.coalesce();
+  const int64_t lhs_nnz = lhs._nnz();
+  const int64_t rhs_nnz = rhs._nnz();
+  if (!lhs_nnz || !rhs_nnz) {
+    r_.resize_as_(lhs);
+    return r_.zero_();
+  }
+
+  // dtype checks and promotion
+  auto commonDtype = at::result_type(lhs, rhs);
+  TORCH_CHECK(canCast(commonDtype, r_.scalar_type()),
+              "Can't convert result type ", commonDtype, " to output ", r_.scalar_type());
+
+  const int64_t nDimI = lhs.sparse_dim();
+
+  // nDimI == 0, at most one structural entry
+  if (nDimI == 0) {
+    r_.resize_as_(lhs);
+    const bool has = (lhs_nnz && rhs_nnz);
+
+    auto out_indices = lhs._indices().narrow(1, 0, has ? 1 : 0);
+
+    Tensor lhs_vals = lhs._values().to(commonDtype);
+    Tensor rhs_vals = rhs._values().to(commonDtype);
+    lhs_vals = lhs_vals.narrow(0, 0, has ? 1 : 0);
+    rhs_vals = rhs_vals.narrow(0, 0, has ? 1 : 0);
+
+    Tensor out_values = lhs_vals.mul(rhs_vals);
+    if (r_.scalar_type() != commonDtype) {
+      out_values = out_values.to(r_.scalar_type());
+    }
+
+    alias_into_sparse(r_, out_indices, out_values);
+    r_._coalesced_(true);
+    return r_;
+  }
+
+  // General path, intersect keys, then gather + multiply on GPU
+  const auto device = r_.device();
+  auto stream = getCurrentMPSStream();
+
+  auto lhs_indices = lhs._indices();
+  auto rhs_indices = rhs._indices();
+  auto lhs_values = lhs._values().to(commonDtype);
+  auto rhs_values = rhs._values().to(commonDtype);
+
+  // Flatten sparse indices to keys
+  auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
+  auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
+
+  // Intersect sorted keys (search the shorter in the longer)
+  const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
+  const int64_t lenA = A_is_lhs ? lhs_nnz : rhs_nnz;
+  const int64_t lenB = A_is_lhs ? rhs_nnz : lhs_nnz;
+  auto A_keys = A_is_lhs ? lhs_keys : rhs_keys;
+  auto B_keys = A_is_lhs ? rhs_keys : lhs_keys;
+
+  auto outA_idx = at::empty({lenA}, at::device(device).dtype(kLong));
+  auto outB_idx = at::empty({lenA}, at::device(device).dtype(kLong));
+  auto counter = at::zeros({1}, at::device(device).dtype(kInt));
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
+                  static_cast<uint32_t>(lenB), A_is_lhs);
+      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
+    }
+  });
+
+  const uint32_t M = counter.item<int32_t>(); // number of structural matches
+
+  r_.resize_as_(lhs);
+
+  auto out_indices = at::empty({nDimI, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
+  auto lhs_match = outA_idx.narrow(0, 0, M);
+  auto rhs_match = outB_idx.narrow(0, 0, M);
+  auto out_val_sizes = lhs_values.sizes().vec();
+  out_val_sizes[0] = static_cast<int64_t>(M);
+  auto out_values = at::empty(out_val_sizes, lhs_values.options());
+
+  const uint32_t cols = static_cast<uint32_t>(
+      lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc(
+          "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+
+      const uint32_t tew = pso.threadExecutionWidth;
+      uint32_t tgW = std::min(cols, tew);
+      MTLSize grid = MTLSizeMake(cols, 1, M);
+      MTLSize tgs = MTLSizeMake(tgW, 1, 1);
+
+      mtl_setArgs(enc,
+                  lhs_values, rhs_values,
+                  lhs_match, rhs_match,
+                  lhs_indices, out_indices,
+                  out_values, static_cast<uint32_t>(nDimI),
+                  static_cast<uint32_t>(lhs_nnz),
+                  static_cast<uint32_t>(M),
+                  cols);
+      [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
+    }
+  });
+
+  if (r_.scalar_type() != commonDtype) {
+    out_values = out_values.to(r_.scalar_type());
+  }
+
+  alias_into_sparse(r_, out_indices, out_values);
+  r_._coalesced_(true);
+  return r_;
+}
+
+static Tensor& add_out_dense_sparse_mps(
     Tensor& out,
     const Tensor& dense,
     const SparseTensor& sparse,
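
The sparse * sparse path above coalesces both operands, flattens their sparse indices into sorted scalar keys, intersects the two key lists with a per-element binary search (intersect_binary_search), and then gathers and multiplies the matching values (fused_gather_mul_kernel). For orientation, here is a hedged CPU-side Python restatement of that strategy; sparse_mul_reference and flatten_keys are illustrative helpers, not part of the PR:

import torch

def sparse_mul_reference(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # Assumes both operands have at least one nonzero; the MPS path
    # early-exits to a zero result when either nnz is 0.
    lhs, rhs = lhs.coalesce(), rhs.coalesce()
    sparse_sizes = lhs.shape[:lhs.sparse_dim()]

    def flatten_keys(idx):
        # Row-major key over the sparse dims, like flatten_indices.
        key = torch.zeros_like(idx[0])
        for d, sz in enumerate(sparse_sizes):
            key = key * sz + idx[d]
        return key  # sorted, because the tensor is coalesced

    lhs_keys = flatten_keys(lhs.indices())
    rhs_keys = flatten_keys(rhs.indices())

    # torch.searchsorted plays the role of intersect_binary_search (lower_bound).
    pos = torch.searchsorted(rhs_keys, lhs_keys).clamp(max=rhs_keys.numel() - 1)
    match = rhs_keys[pos] == lhs_keys

    lhs_sel = match.nonzero().squeeze(1)   # analogous to outA_idx / outB_idx
    rhs_sel = pos[match]

    out_indices = lhs.indices()[:, lhs_sel]
    out_values = lhs.values()[lhs_sel] * rhs.values()[rhs_sel]
    return torch.sparse_coo_tensor(out_indices, out_values, lhs.shape)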
aten/src/ATen/native/sparse/mps/kernels/Mul.metal (new file, 198 lines)
@@ -0,0 +1,198 @@
+#include <metal_stdlib>
+#include <c10/metal/indexing.h>
+using namespace metal;
+
+template <typename T>
+kernel void dense_sparse_mul_kernel(
+    device const T* dense [[buffer(0)]],
+    device const T* values [[buffer(1)]],
+    device T* out_values [[buffer(2)]],
+    device const long* indices [[buffer(3)]],
+    device const long* sizes [[buffer(4)]],
+    constant uint& nnz [[buffer(5)]],
+    constant uint& ndim_i [[buffer(6)]],
+    constant uint& view_cols [[buffer(7)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+  uint col = gid.x;
+  uint i = gid.z;
+
+  long key = 0;
+  for (uint d = 0; d < ndim_i; ++d) {
+    long idx_d = indices[(ulong)d * (ulong)nnz + (ulong)i];
+    const auto sz_d = sizes[d];
+    key = key * sz_d + idx_d;
+  }
+
+  ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
+  ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
+
+  const auto a = static_cast<float>(values[val_idx]);
+  const auto b = static_cast<float>(dense[dense_idx]);
+  out_values[val_idx] = static_cast<T>(a * b);
+}
+
+kernel void intersect_binary_search(
+    device const long* keysA [[buffer(0)]],
+    device const long* keysB [[buffer(1)]],
+    device long* outA_idx [[buffer(2)]],
+    device long* outB_idx [[buffer(3)]],
+    device atomic_uint* counter [[buffer(4)]],
+    constant uint& lenB [[buffer(5)]],
+    constant bool& A_is_lhs [[buffer(6)]],
+    uint3 tid_in_grid [[thread_position_in_grid]])
+{
+  uint gid = tid_in_grid.x;
+
+  long key = keysA[gid];
+
+  // lower_bound in B
+  uint lo = 0;
+  uint hi = lenB;
+  while (lo < hi) {
+    uint mid = (lo + hi) >> 1;
+    long v = keysB[mid];
+    if (v < key) lo = mid + 1;
+    else hi = mid;
+  }
+
+  if (lo < lenB && keysB[lo] == key) {
+    uint pos = atomic_fetch_add_explicit(counter, 1u, memory_order_relaxed);
+    if (A_is_lhs) {
+      outA_idx[pos] = (long)gid;
+      outB_idx[pos] = (long)lo;
+    } else {
+      outA_idx[pos] = (long)lo;
+      outB_idx[pos] = (long)gid;
+    }
+  }
+}
+
+template <typename T>
+kernel void fused_gather_mul_kernel(
+    device const T* lhs_vals [[buffer(0)]],
+    device const T* rhs_vals [[buffer(1)]],
+    device const long* lhs_sel [[buffer(2)]],
+    device const long* rhs_sel [[buffer(3)]],
+    device const long* lhs_indices [[buffer(4)]],
+    device long* out_indices [[buffer(5)]],
+    device T* out_vals [[buffer(6)]],
+    constant uint& nDimI [[buffer(7)]],
+    constant uint& L [[buffer(8)]],
+    constant uint& M [[buffer(9)]],
+    constant uint& view_cols [[buffer(10)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+  const uint col = gid.x;
+  const uint k = gid.z;
+
+  const long iL = lhs_sel[k];
+  const long iR = rhs_sel[k];
+
+  if (col < view_cols) {
+    const ulong offL = (ulong)iL * (ulong)view_cols + (ulong)col;
+    const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
+    const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
+
+    const float a = (float)lhs_vals[offL];
+    const float b = (float)rhs_vals[offR];
+    out_vals[offO] = (T)(a * b);
+  }
+
+  // One thread per match copies the indices column
+  if (col == 0) {
+    const ulong uL = (ulong)L;
+    const ulong uM = (ulong)M;
+    const ulong src_col = (ulong)iL; // gather from lhs
+    for (uint d = 0; d < nDimI; ++d) {
+      const long v = lhs_indices[(ulong)d * uL + src_col];
+      out_indices[(ulong)d * uM + (ulong)k] = v;
+    }
+  }
+}
+
+kernel void gather_int64_columns(
+    device const long* in [[buffer(0)]],
+    device const long* sel [[buffer(1)]],
+    device long* out [[buffer(2)]],
+    constant uint& nDimI [[buffer(3)]],
+    constant uint& L [[buffer(4)]],
+    constant uint& M [[buffer(5)]],
+    uint g [[thread_position_in_grid]])
+{
+  uint d32 = g / M;
+  uint k32 = g - d32 * M;
+
+  ulong d = (ulong)d32;
+  ulong k = (ulong)k32;
+
+  long src_col = sel[k];
+
+  long v = in[d * (ulong)L + (ulong)src_col];
+  out[d * (ulong)M + k] = v;
+}
+
+template <typename T>
+kernel void pairwise_gather_mul_kernel(
+    device const T* lhs_vals [[buffer(0)]],
+    device const T* rhs_vals [[buffer(1)]],
+    device const long* lhs_sel [[buffer(2)]],
+    device const long* rhs_sel [[buffer(3)]],
+    device T* out_vals [[buffer(4)]],
+    constant uint& M [[buffer(5)]],
+    constant uint& view_cols [[buffer(6)]],
+    uint3 gid [[thread_position_in_grid]])
+{
+  uint col = gid.x;
+  uint k = gid.z;
+
+  long iL = lhs_sel[k];
+  long iR = rhs_sel[k];
+
+  ulong offL = (ulong)iL * (ulong)view_cols + (ulong)col;
+  ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
+  ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
+
+  float a = (float)lhs_vals[offL];
+  float b = (float)rhs_vals[offR];
+  out_vals[offO] = (T)(a * b);
+}
+
+#define INSTANTIATE_DENSE_SPARSE_MUL(DTYPE)                              \
+  template [[host_name("dense_sparse_mul_kernel_" #DTYPE)]] kernel void \
+  dense_sparse_mul_kernel<DTYPE>(                                        \
+      device const DTYPE* dense [[buffer(0)]],                           \
+      device const DTYPE* values [[buffer(1)]],                          \
+      device DTYPE* out_values [[buffer(2)]],                            \
+      device const long* indices [[buffer(3)]],                          \
+      device const long* sizes [[buffer(4)]],                            \
+      constant uint& nnz [[buffer(5)]],                                  \
+      constant uint& nDimI [[buffer(6)]],                                \
+      constant uint& view_cols [[buffer(7)]],                            \
+      uint3 gid [[thread_position_in_grid]]);
+
+INSTANTIATE_DENSE_SPARSE_MUL(float);
+INSTANTIATE_DENSE_SPARSE_MUL(half);
+INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
+
+#define INSTANTIATE_FUSED_GATHER_MUL(DTYPE)                              \
+  template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void \
+  fused_gather_mul_kernel<DTYPE>(                                        \
+      device const DTYPE* lhs_vals [[buffer(0)]],                        \
+      device const DTYPE* rhs_vals [[buffer(1)]],                        \
+      device const long* lhs_sel [[buffer(2)]],                          \
+      device const long* rhs_sel [[buffer(3)]],                          \
+      device const long* lhs_indices [[buffer(4)]],                      \
+      device long* out_indices [[buffer(5)]],                            \
+      device DTYPE* out_vals [[buffer(6)]],                              \
+      constant uint& nDimI [[buffer(7)]],                                \
+      constant uint& L [[buffer(8)]],                                    \
+      constant uint& M [[buffer(9)]],                                    \
+      constant uint& view_cols [[buffer(10)]],                           \
+      uint3 gid [[thread_position_in_grid]]);
+
+INSTANTIATE_FUSED_GATHER_MUL(float);
+INSTANTIATE_FUSED_GATHER_MUL(half);
+INSTANTIATE_FUSED_GATHER_MUL(bfloat);
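
In dense_sparse_mul_kernel, gid.z selects one nonzero and gid.x selects one element of the trailing dense dimensions; the nonzero's sparse-dim indices are flattened into a row key into the dense operand viewed as [view_rows, view_cols]. A rough Python restatement of that index math, assuming contiguous inputs (dense_sparse_mul_reference is an illustrative name, not part of the PR):

import torch

def dense_sparse_mul_reference(dense, indices, values, sizes, ndim_i):
    nnz = values.shape[0]
    view_cols = values.numel() // max(nnz, 1)   # product of the dense dims
    dense2d = dense.reshape(-1, view_cols)      # [view_rows, view_cols]
    vals2d = values.reshape(nnz, view_cols)
    out = torch.empty_like(vals2d)
    for i in range(nnz):                        # gid.z: one nonzero per grid plane
        key = 0
        for d in range(ndim_i):                 # flatten sparse-dim indices to a row key
            key = key * int(sizes[d]) + int(indices[d, i])
        for col in range(view_cols):            # gid.x: one dense-dim element per thread
            out[i, col] = vals2d[i, col] * dense2d[key, col]
    return out.reshape(values.shape)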
@@ -1108,8 +1108,8 @@ class TestSparse(TestSparseBase):
         test_shape(2, 20, [3, 17, 19, 5])
         test_shape(2, 20, [3, 17, 19, 0])

-    @expectedFailureMPS
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfMPS(torch.float32, torch.complex64)
     def test_add_sub_nnz(self, device, dtype):
         # nnz should not grow unbounded (gh-34964)
         x = torch.randn(10, dtype=dtype, device=device).to_sparse()
@@ -1687,8 +1687,8 @@ class TestSparse(TestSparseBase):
         test_shape(7, 8, 9, 20, True)

     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error")
     @gradcheck_semantics()
     def test_sparse_mul(self, device, dtype, coalesced, gradcheck):
@@ -1868,8 +1868,8 @@ class TestSparse(TestSparseBase):
             x.norm(**kwargs)

     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     @unittest.skipIf(TEST_WITH_CROSSREF, "fallback triggers cuda device error")
     def test_sparse_sum(self, device, dtype, coalesced):

@@ -1933,7 +1933,6 @@ class TestSparse(TestSparseBase):
             S = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
             run_tests(S.requires_grad_(True), test_dim)

-    @expectedFailureMPS
     def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device, coalesced):
         shape = shape_i + (shape_v)
         x1, _, _ = self._gen_sparse(len(shape_i), nnz_x1, shape, dtype, device, coalesced)
@@ -2011,6 +2010,7 @@ class TestSparse(TestSparseBase):

     @coalescedonoff
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     def test_basic_ops(self, device, dtype, coalesced):

         def _test_basic_ops():