[MPS] sparse mask implementation (#165102)

sparse mask implementation

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165102
Approved by: https://github.com/malfet

Committed by: PyTorch MergeBot
Parent: e6033f6efb
Commit: 8573574b32
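
For orientation: Tensor.sparse_mask(mask) returns a sparse tensor holding the values of self at the nonzero locations of the sparse COO mask, and this commit wires the op up for the MPS backend. A minimal usage sketch (hypothetical example, assuming a Metal-capable machine running a build with this change):

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"

dense = torch.randn(3, 3, device=device)
mask = torch.eye(3, device=device).to_sparse_coo()

# Keep only the entries of `dense` at the mask's nonzero positions.
out = dense.sparse_mask(mask)
print(out._nnz())  # 3 values, taken from the diagonal of `dense`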
aten/src/ATen/native/native_functions.yaml

@@ -7384,7 +7384,7 @@
 - func: sparse_mask(Tensor self, Tensor mask) -> Tensor
   variants: method
   dispatch:
-    SparseCPU, SparseCUDA: sparse_mask
+    SparseCPU, SparseCUDA, SparseMPS: sparse_mask
     SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_mask_sparse_compressed
   autogen: sparse_mask.out
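The functional change in this hunk is appending SparseMPS to the dispatch line, so sparse COO tensors on MPS now route to the sparse_mask entry point instead of failing with a missing-backend error. A hypothetical smoke check:

import torch

# Both operands are sparse COO on MPS, exercising the new SparseMPS entry.
s = torch.randn(4, 4, device="mps").to_sparse_coo()
m = (torch.rand(4, 4, device="mps") > 0.5).float().to_sparse_coo()
print(s.sparse_mask(m).device)  # mps:0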
@@ -1,6 +1,8 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/SparseTensorUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/sparse/SparseStubs.h>
+#include <ATen/native/sparse/SparseBinaryOpIntersectionCommon.h>

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -13,6 +15,8 @@
 #include <ATen/ops/mul_native.h>
 #include <ATen/ops/empty_native.h>
 #include <ATen/ops/zeros_native.h>
+#include <ATen/ops/ones_like.h>
+#include <ATen/ops/argsort.h>
 #include <ATen/ops/result_type.h>
 #include <ATen/ops/copy_sparse_to_sparse.h>
 #include <ATen/ops/mul.h>
@@ -436,4 +440,137 @@ SparseTensor& add_out_sparse_mps(const SparseTensor& self,
   return out;
 }
 
+using OptTensor = std::optional<Tensor>;
+
+static void sparse_mask_apply_out_mps_kernel(
+    Tensor& result,
+    const Tensor& src_in,
+    const Tensor& mask_in,
+    bool accumulate_matches,
+    bool require_same_sizes,
+    bool coalesce_mask) {
+  TORCH_CHECK(src_in.is_sparse() && mask_in.is_sparse(),
+              "sparse_mask: expected both inputs to be sparse COO");
+  TORCH_CHECK(src_in.is_mps() && mask_in.is_mps(),
+              "sparse_mask: expected tensors to be on MPS device");
+  TORCH_CHECK(src_in.sparse_dim() == mask_in.sparse_dim(),
+              "sparse_mask: sparse_dim mismatch: ", src_in.sparse_dim(), " vs ", mask_in.sparse_dim());
+  if (require_same_sizes) {
+    TORCH_CHECK(src_in.sizes().equals(mask_in.sizes()),
+                "sparse_mask: sizes must match exactly (no broadcasting)");
+  }
+  auto src = src_in.coalesce();
+  auto mask = coalesce_mask ? mask_in.coalesce() : mask_in;
+
+  const int64_t src_nnz = src._nnz();
+  const int64_t mask_nnz = mask._nnz();
+  const int64_t sd = src.sparse_dim();
+  result.sparse_resize_(mask.sizes(), mask.sparse_dim(), mask.dense_dim());
+
+  auto commonDtype = at::result_type(src, mask);
+  TORCH_CHECK(canCast(commonDtype, result.scalar_type()),
+              "Can't convert result type ", commonDtype, " to output ", result.scalar_type());
+
+  if (mask_nnz == 0) {
+    alias_into_sparse(
+        result,
+        mask._indices().narrow(1, 0, 0),
+        at::empty({0}, result.options().dtype(result.scalar_type())));
+    result._coalesced_(mask.is_coalesced());
+    return;
+  }
+
+  TORCH_CHECK(sd > 0 || (src_nnz <= 1 && mask_nnz <= 1),
+              "sparse_mask: invalid sparse_dim or nnz");
+
+  if (sd == 0) {
+    auto out_indices = mask._indices().narrow(1, 0, 1);
+    auto out_values = src_nnz
+        ? src._values().narrow(0, 0, 1).to(commonDtype)
+        : at::zeros({1}, at::device(result.device()).dtype(commonDtype));
+    alias_into_sparse(result, out_indices, out_values);
+    result._coalesced_(mask.is_coalesced());
+    return;
+  }
+
+  if (src_nnz == 0) {
+    auto out_indices = mask._indices().contiguous();
+    auto src_values = src._values().to(commonDtype);
+    auto out_val_sizes = src_values.sizes().vec();
+    out_val_sizes[0] = mask_nnz;
+    auto out_values = at::zeros(out_val_sizes, src_values.options());
+    alias_into_sparse(result, out_indices, out_values);
+    result._coalesced_(mask.is_coalesced());
+    return;
+  }
+
+  auto mask_indices = mask._indices().contiguous();
+  auto src_indices = src._indices().contiguous();
+  auto src_values = src._values().to(commonDtype).contiguous();
+
+  auto mask_keys = flatten_indices(mask_indices, mask.sizes().slice(0, sd)).contiguous();
+  auto src_keys = flatten_indices(src_indices, src.sizes().slice(0, sd)).contiguous();
+
+  const bool A_is_src = (src_nnz <= mask_nnz);
+  const int64_t lenA = A_is_src ? src_nnz : mask_nnz;
+  const int64_t lenB = A_is_src ? mask_nnz : src_nnz;
+  auto A_keys = A_is_src ? src_keys : mask_keys;
+  auto B_keys = A_is_src ? mask_keys : src_keys;
+
+  const auto device = result.device();
+  auto stream = getCurrentMPSStream();
+
+  auto outA_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
+  auto outB_idx = at::empty({lenA}, at::device(device).dtype(at::kLong));
+  auto counter = at::zeros({1}, at::device(device).dtype(at::kInt));
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      auto pso = lib.getPipelineStateForFunc("intersect_binary_search");
+      auto enc = stream->commandEncoder();
+      [enc setComputePipelineState:pso];
+      mtl_setArgs(enc, A_keys, B_keys, outA_idx, outB_idx, counter,
+                  static_cast<uint32_t>(lenB), A_is_src);
+      mtl_dispatch1DJob(enc, pso, static_cast<uint32_t>(lenA));
+    }
+  });
+
+  const int64_t M = static_cast<int64_t>(counter.item<int32_t>());
+
+  auto out_val_sizes = src_values.sizes().vec();
+  out_val_sizes[0] = mask_nnz;
+  auto out_values = at::zeros(out_val_sizes, src_values.options());
+
+  if (M > 0) {
+    auto src_match = outA_idx.narrow(0, 0, M);
+    auto mask_match = outB_idx.narrow(0, 0, M);
+
+    auto src_rows = src_values.index_select(0, src_match);
+    if (accumulate_matches) {
+      out_values.index_add_(0, mask_match, src_rows);
+    } else {
+      out_values.index_copy_(0, mask_match, src_rows);
+    }
+  }
+
+  alias_into_sparse(result, mask_indices, out_values);
+  result._coalesced_(mask.is_coalesced());
+}
+
+static void sparse_mask_intersection_out_mps_kernel(
+    Tensor& result,
+    const Tensor& lhs,
+    const Tensor& rhs,
+    const OptTensor& = std::nullopt) {
+  sparse_mask_apply_out_mps_kernel(
+      result,
+      /*src_in=*/lhs,
+      /*mask_in=*/rhs,
+      /*accumulate_matches=*/false,
+      /*require_same_sizes=*/false,
+      /*coalesce_mask=*/false);
+}
+
+REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
 } // namespace at::native
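In words, the host logic above does the following: coalesce both tensors, flatten each one's sparse indices into scalar keys, intersect the smaller key list against the larger (sorted) one with a GPU binary search (the intersect_binary_search Metal kernel), then scatter the matched source rows into a zero-filled value buffer with one row per mask nonzero. A hypothetical pure-PyTorch model of the same computation, for readability only (invented names; assumes nonempty, equal-shaped inputs with matching sparse_dim):

import torch

def sparse_mask_reference(src, mask):
    # Device-agnostic model of the MPS kernel above; not the shipped code.
    src, mask = src.coalesce(), mask.coalesce()
    sd = src.sparse_dim()

    # flatten_indices analogue: row-major linearization of the sparse indices.
    strides = torch.ones(sd, dtype=torch.long)
    for i in range(sd - 2, -1, -1):
        strides[i] = strides[i + 1] * mask.shape[i + 1]
    src_keys = (src._indices() * strides.unsqueeze(1)).sum(0)
    mask_keys = (mask._indices() * strides.unsqueeze(1)).sum(0)

    # torch.searchsorted stands in for the intersect_binary_search kernel;
    # coalesced COO indices are lexicographically sorted, hence sorted keys.
    pos = torch.searchsorted(src_keys, mask_keys).clamp(max=src_keys.numel() - 1)
    hit = src_keys[pos] == mask_keys

    # index_copy_ analogue: matched source rows land at the mask's slots,
    # everything else stays zero.
    out_values = torch.zeros((mask._nnz(), *src._values().shape[1:]),
                             dtype=src.dtype)
    out_values[hit] = src._values()[pos[hit]]
    return torch.sparse_coo_tensor(mask._indices(), out_values, mask.shape)

x = torch.randn(5, 5).to_sparse_coo()
m = (torch.rand(5, 5) > 0.5).float().to_sparse_coo()
assert torch.equal(sparse_mask_reference(x, m).to_dense(),
                   x.sparse_mask(m).to_dense())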
@@ -3,6 +3,9 @@
 using namespace metal;
 
+template <typename T> struct MulAccum { using type = float; };
+template <> struct MulAccum<float2> { using type = float2; };
+
 template <typename T>
 kernel void dense_sparse_mul_kernel(
     device const T* dense [[buffer(0)]],
@@ -29,8 +32,9 @@ kernel void dense_sparse_mul_kernel(
   ulong dense_idx = (ulong)key * (ulong)view_cols + (ulong)col;
   ulong val_idx = (ulong)i * (ulong)view_cols + (ulong)col;
 
-  const auto a = static_cast<float>(values[val_idx]);
-  const auto b = static_cast<float>(dense[dense_idx]);
+  using accum_t = typename MulAccum<T>::type;
+  const accum_t a = static_cast<accum_t>(values[val_idx]);
+  const accum_t b = static_cast<accum_t>(dense[dense_idx]);
   out_values[val_idx] = static_cast<T>(a * b);
 }
 
@@ -130,6 +134,8 @@ kernel void fused_gather_mul_kernel(
 INSTANTIATE_DENSE_SPARSE_MUL(float);
 INSTANTIATE_DENSE_SPARSE_MUL(half);
 INSTANTIATE_DENSE_SPARSE_MUL(bfloat);
+INSTANTIATE_DENSE_SPARSE_MUL(long);
+INSTANTIATE_DENSE_SPARSE_MUL(float2);
 
 #define INSTANTIATE_FUSED_GATHER_MUL(DTYPE) \
   template [[host_name("fused_gather_mul_kernel_" #DTYPE)]] kernel void \
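The MulAccum helper picks the accumulation type per element type: real inputs keep the existing float accumulation, while float2 (complex64 packed as two floats) stays float2, without which the old static_cast<float> lines could not even instantiate for a vector type; the two new INSTANTIATE_DENSE_SPARSE_MUL lines then expose the kernel for long and float2. A hypothetical check of the integer path (assuming sparse*dense elementwise mul on MPS routes through dense_sparse_mul_kernel):

import torch

# Hedged sketch: int64 dense*sparse mul exercising INSTANTIATE_DENSE_SPARSE_MUL(long).
d = torch.randint(0, 10, (3, 3), device="mps")
s = torch.randint(0, 10, (3, 3), device="mps").to_sparse_coo()
print((s * d).to_dense())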
test/test_sparse.py

@@ -2099,7 +2099,6 @@ class TestSparse(TestSparseBase):
         self.assertEqual(self.safeToDense(y2), expected)
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double, torch.cdouble)
     @dtypesIfMPS(torch.float32, torch.complex64)
     def test_sparse_mask(self, device, dtype, coalesced):
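
Dropping @expectedFailureMPS marks test_sparse_mask as expected to pass on MPS, and @dtypesIfMPS(torch.float32, torch.complex64) substitutes MPS-supported dtypes for the double/cdouble used elsewhere (MPS has no 64-bit floats). A hypothetical cross-device consistency check in the same spirit (assuming a build with MPS sparse support):

import torch

x = torch.randn(4, 4).to_sparse_coo()
m = (torch.rand(4, 4) > 0.5).float().to_sparse_coo()

cpu_out = x.sparse_mask(m)
mps_out = x.to("mps").sparse_mask(m.to("mps"))
torch.testing.assert_close(mps_out.cpu().to_dense(), cpu_out.to_dense())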