[MPS] coalesce for sparse tensors (#159729)

Adds a coalesce() implementation for sparse COO tensors on the MPS backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Author: Isalia20
Date: 2025-08-08 13:49:55 +00:00
Committed by: PyTorch MergeBot
Parent: 556e2a73f4
Commit: 7f4cb4a3e0
9 changed files with 416 additions and 11 deletions
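
For reference, the new kernel is reached through the regular sparse COO API. A minimal usage sketch in Python, mirroring the tests added in this commit (assumes a Metal-capable build that includes this change):

import torch

# Duplicate coordinates (0, 0) and (1, 2); coalesce() sums their values.
indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], device="mps")
values = torch.tensor([1., 2., 3., 4.], device="mps")
sparse = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")

coalesced = sparse.coalesce()      # dispatches to _coalesce_sparse_mps
print(coalesced._nnz())            # 2 unique coordinates
print(coalesced._values().cpu())   # tensor([3., 7.])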


@@ -119,6 +119,8 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp")
 file(GLOB_RECURSE native_mps_mm "native/mps/*.mm")
 file(GLOB_RECURSE native_mps_metal "native/mps/*.metal")
 file(GLOB_RECURSE native_mps_h "native/mps/*.h")
+file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm")
+file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal")
 file(GLOB native_sparse_cpp "native/sparse/*.cpp")
 file(GLOB native_quantized_cpp
@@ -699,10 +701,10 @@ endif()
 if(USE_MPS)
   include(../../../cmake/Metal.cmake)
-  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h})
+  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm})
   if(CAN_COMPILE_METAL)
-    foreach(SHADER ${native_mps_metal})
+    foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
       cmake_path(GET SHADER STEM TGT_STEM)
       string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
       list(APPEND AIR_BASIC ${TGT_BASIC})
@@ -717,7 +719,7 @@ if(USE_MPS)
     add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
   else()
     file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
-    foreach(SHADER ${native_mps_metal})
+    foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
      cmake_path(GET SHADER STEM TGT_STEM)
      string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h")
      metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME})


@@ -7423,6 +7423,7 @@
   dispatch:
     SparseCPU: _coalesce_sparse_cpu
     SparseCUDA: _coalesce_sparse_cuda
+    SparseMPS: _coalesce_sparse_mps
   autogen: _coalesce.out

 - func: is_coalesced(Tensor self) -> bool


@@ -0,0 +1,220 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#endif
namespace at::native {
using namespace mps;
using namespace at::sparse;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Sparse_metallib.h>
#endif
static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {
TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
"flatten_indices: indices.size(0) must equal size.size()");
int64_t sparse_dim = indices.size(0);
int64_t nnz = indices.size(1);
if (nnz == 0) {
return at::empty({0}, indices.options().dtype(kLong));
}
std::vector<int64_t> strides(sparse_dim);
strides[sparse_dim - 1] = 1;
for (int64_t i = sparse_dim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * size[i + 1];
}
Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return flat_indices;
}
static Tensor compute_output_positions(const Tensor& is_unique) {
int64_t nnz = is_unique.size(0);
if (nnz == 0) {
return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
}
Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, is_unique, positions);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return positions;
}
static Tensor compute_output_positions_parallel(const Tensor& is_unique) {
int64_t nnz = is_unique.size(0);
if (nnz == 0) {
return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
}
// for small arrays, use simple kernel
// speed of the naive kernel drops off after 4096 nnz elements
if (nnz <= 4096) {
return compute_output_positions(is_unique);
}
auto stream = getCurrentMPSStream();
Tensor positions = is_unique.to(kInt);
// Kogge-Stone parallel prefix sum
Tensor positions_cloned = positions.clone();
for (int64_t stride = 1; stride < nnz; stride *= 2) {
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, positions, positions_cloned, stride);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
std::swap(positions, positions_cloned);
}
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, positions, positions_cloned);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return positions_cloned;
}
static std::pair<Tensor, int32_t> mark_unique_and_count(const Tensor& flat_indices) {
int64_t nnz = flat_indices.size(0);
if (nnz == 0) {
return {at::empty({0}, flat_indices.options().dtype(kBool)), 0};
}
Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool));
Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, flat_indices, is_unique, count_result);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
int32_t num_unique = count_result.item<int32_t>();
return {is_unique, num_unique};
}
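// Host-side driver: flatten the sparse indices, argsort them, mark the first
// entry of every run of equal keys, exclusive-scan those flags into output
// positions, then launch one kernel that writes the unique indices and sums
// the values of each run.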
SparseTensor _coalesce_sparse_mps(const SparseTensor& self) {
int64_t nnz = self._nnz();
TORCH_INTERNAL_ASSERT(!self.is_coalesced());
if (nnz < 2) {
SparseTensor dst = self.clone();
dst._coalesced_(true);
return dst;
}
Tensor indices = self._indices();
Tensor values = self._values();
Tensor flat_indices = flatten_indices(indices, self.sizes());
Tensor sorted_order = flat_indices.argsort();
Tensor flat_indices_sorted = flat_indices.index({sorted_order});
values = values.index({sorted_order});
indices = indices.index_select(1, sorted_order);
auto unique_info = mark_unique_and_count(flat_indices_sorted);
Tensor is_unique = unique_info.first;
int32_t newNnz = unique_info.second;
Tensor output_positions = compute_output_positions_parallel(is_unique);
Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options());
auto outValuesSize = values.sizes().vec();
outValuesSize[0] = newNnz;
Tensor out_values = at::zeros(outValuesSize, values.options());
Tensor is_unique_local = is_unique;
int64_t sparse_dim = indices.size(0);
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values));
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
const uint32_t numThreads = static_cast<uint32_t>(nnz);
const uint32_t valueSize = static_cast<uint32_t>(values.numel() / nnz);
mtl_setArgs(encoder,
flat_indices_sorted,
indices,
values,
is_unique_local,
output_positions,
out_indices,
out_values,
numThreads,
valueSize,
sparse_dim,
newNnz);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true);
return result;
}
} // namespace at::native
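
The host-side pipeline above boils down to: flatten the multi-dimensional indices into scalar keys, argsort, mark the first entry of every run of equal keys, exclusive-scan those flags into output positions, and sum each run into its slot. A rough CPU-side sketch of the same sequence in plain PyTorch (illustration only; coalesce_reference is a hypothetical helper, not part of this PR):

import torch

def coalesce_reference(indices, values, size):
    # indices: (sparse_dim, nnz) int64; values: (nnz, ...); assumes nnz >= 1
    # (the MPS op returns early for nnz < 2, so the kernels never see that case).
    strides = torch.ones(len(size), dtype=torch.long)
    for d in range(len(size) - 2, -1, -1):          # row-major strides, as in flatten_indices
        strides[d] = strides[d + 1] * size[d + 1]
    flat = (indices * strides.unsqueeze(1)).sum(0)  # flatten_indices_kernel
    order = torch.argsort(flat, stable=True)        # sort entries by flattened key
    flat, indices, values = flat[order], indices[:, order], values[order]
    is_unique = torch.cat([torch.tensor([True]), flat[1:] != flat[:-1]])
    out_pos = torch.cumsum(is_unique, 0) - 1        # exclusive scan of the flags
    out_indices = indices[:, is_unique]
    out_values = torch.zeros((int(is_unique.sum()),) + tuple(values.shape[1:]), dtype=values.dtype)
    out_values.index_add_(0, out_pos, values)       # sum each run of duplicates
    return out_indices, out_values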


@@ -0,0 +1,123 @@
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
kernel void flatten_indices_kernel(
device const int64_t* indices [[buffer(0)]],
device const int64_t* strides [[buffer(1)]],
device int64_t* flat_indices [[buffer(2)]],
constant uint& sparse_dim [[buffer(3)]],
constant uint& nnz [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
int64_t flat_idx = 0;
for (uint d = 0; d < sparse_dim; d++) {
flat_idx += indices[d * nnz + gid] * strides[d];
}
flat_indices[gid] = flat_idx;
}
kernel void compute_output_positions_kernel(
device const bool* is_unique [[buffer(0)]],
device int* positions [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
int pos = 0;
for (uint i = 0; i < gid; i++) {
if (is_unique[i])
pos++;
}
positions[gid] = pos;
}
kernel void mark_unique_positions_and_count_kernel(
device const int64_t* flat_indices [[buffer(0)]],
device bool* is_unique [[buffer(1)]],
device atomic_int* count [[buffer(2)]],
uint tid [[thread_position_in_grid]]) {
bool unique = (tid == 0) || (flat_indices[tid] != flat_indices[tid - 1]);
is_unique[tid] = unique;
if (unique) {
atomic_fetch_add_explicit(count, 1, memory_order_relaxed);
}
}
// Kogge-Stone parallel prefix sum step
kernel void kogge_stone_step(
device const int* input [[buffer(0)]],
device int* output [[buffer(1)]],
constant uint& stride [[buffer(2)]],
uint gid [[thread_position_in_grid]]) {
int val = input[gid];
if (gid >= stride) {
val += input[gid - stride];
}
output[gid] = val;
}
// Shift right for exclusive scan
kernel void shift_right_kernel(
device const int* input [[buffer(0)]],
device int* output [[buffer(1)]],
uint gid [[thread_position_in_grid]]) {
output[gid] = (gid == 0) ? 0 : input[gid - 1];
}
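// One thread per sorted entry: a thread whose entry starts a run of equal
// flattened indices copies its multi-dimensional index into the slot given by
// output_positions and accumulates the values of the whole run.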
template <typename T>
kernel void coalesce_with_positions_kernel(
device const int64_t* flat_indices [[buffer(0)]],
device const int64_t* indices [[buffer(1)]],
device const T* in_values [[buffer(2)]],
device const bool* is_unique [[buffer(3)]],
device const int* output_positions [[buffer(4)]],
device int64_t* out_indices [[buffer(5)]],
device T* out_values [[buffer(6)]],
constant uint& nnz [[buffer(7)]],
constant uint& value_size [[buffer(8)]],
constant uint& sparse_dim [[buffer(9)]],
constant uint& total_unique [[buffer(10)]],
uint gid [[thread_position_in_grid]]) {
if (!is_unique[gid])
return;
int out_pos = output_positions[gid];
for (uint d = 0; d < sparse_dim; d++) {
out_indices[d * total_unique + out_pos] = indices[d * nnz + gid];
}
int64_t current_index = flat_indices[gid];
uint end = gid + 1;
while (end < nnz && flat_indices[end] == current_index) {
end++;
}
for (uint elem = 0; elem < value_size; elem++) {
T sum = 0;
for (uint j = gid; j < end; j++) {
sum += in_values[j * value_size + elem];
}
out_values[out_pos * value_size + elem] = sum;
}
}
#define INSTANTIATE_COALESCE_WITH_POSITIONS(DTYPE) \
template \
[[host_name("coalesce_with_positions_kernel_" #DTYPE)]] [[kernel]] void \
coalesce_with_positions_kernel<DTYPE>( \
device const int64_t* flat_indices [[buffer(0)]], \
device const int64_t* indices [[buffer(1)]], \
device const DTYPE* in_values [[buffer(2)]], \
device const bool* is_unique [[buffer(3)]], \
device const int* output_positions [[buffer(4)]], \
device int64_t* out_indices [[buffer(5)]], \
device DTYPE* out_values [[buffer(6)]], \
constant uint& nnz [[buffer(7)]], \
constant uint& value_size [[buffer(8)]], \
constant uint& sparse_dim [[buffer(9)]], \
constant uint& total_unique [[buffer(10)]], \
uint gid [[thread_position_in_grid]]);
INSTANTIATE_COALESCE_WITH_POSITIONS(float);
INSTANTIATE_COALESCE_WITH_POSITIONS(half);
INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat);
INSTANTIATE_COALESCE_WITH_POSITIONS(bool);
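
The kogge_stone_step / shift_right_kernel pair implements an exclusive prefix sum over the uniqueness flags in O(log n) passes: each pass adds the element `stride` slots to the left, doubling the stride, and the final shift turns the inclusive scan into an exclusive one. A small Python sketch of the same idea (illustration only; exclusive_scan_kogge_stone is a hypothetical name):

def exclusive_scan_kogge_stone(flags):
    # flags: 0/1 uniqueness markers, one per sorted entry
    vals = list(flags)
    stride = 1
    while stride < len(vals):
        # kogge_stone_step: each element adds the value `stride` slots behind it
        vals = [v + vals[i - stride] if i >= stride else v for i, v in enumerate(vals)]
        stride *= 2
    # shift_right_kernel: turn the inclusive scan into an exclusive scan
    return [0] + vals[:-1]

exclusive_scan_kogge_stone([1, 0, 1, 1])  # -> [0, 1, 1, 2]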


@@ -237,8 +237,6 @@ inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::CPU;
     case Backend::CUDA:
     case Backend::SparseCUDA:
-    case Backend::SparseMPS:
-    case Backend::SparseCsrMPS:
     case Backend::QuantizedCUDA:
     case Backend::SparseCsrCUDA:
       return DeviceType::CUDA;
@@ -276,6 +274,8 @@ inline DeviceType backendToDeviceType(Backend b) {
     case Backend::Meta:
       return DeviceType::Meta;
     case Backend::MPS:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
       return DeviceType::MPS;
     case Backend::HPU:
       return DeviceType::HPU;


@@ -33,7 +33,6 @@ inline Layout layout_from_backend(Backend backend) {
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
     case Backend::SparseMPS:
-    case Backend::SparseCsrMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparseXPU:
@@ -43,6 +42,7 @@ inline Layout layout_from_backend(Backend backend) {
       return Layout::Mkldnn;
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
+    case Backend::SparseCsrMPS:
     case Backend::SparseCsrHIP:
     case Backend::SparseCsrVE:
     case Backend::SparseCsrXPU:


@@ -2090,6 +2090,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     constexpr auto sparse_backends = DispatchKeySet(
         {BackendComponent::CPUBit,
          BackendComponent::CUDABit,
+         BackendComponent::MPSBit,
          BackendComponent::HIPBit,
          BackendComponent::XPUBit});
     constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);


@@ -12696,6 +12696,65 @@ class TestSparseMPS(TestCaseMPS):
        sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
        self.assertEqual(sparse, sparse_cpu)
def test_coalesce(self):
indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps")
values = torch.tensor([1., 2., 3., 4.], dtype=torch.float32, device="mps")
size = (2, 3)
indices_cpu = indices.cpu()
values_cpu = values.cpu()
sparse_mps = torch.sparse_coo_tensor(indices, values, size, device="mps")
sparse_cpu = torch.sparse_coo_tensor(indices_cpu, values_cpu, size, device="cpu")
coalesced_mps = sparse_mps.coalesce()
coalesced_cpu = sparse_cpu.coalesce()
self.assertTrue(coalesced_mps.is_coalesced())
self.assertTrue(coalesced_cpu.is_coalesced())
self.assertEqual(coalesced_mps._nnz(), 2)
self.assertEqual(coalesced_mps.cpu(), coalesced_cpu)
def test_already_coalesced_tensor(self):
already_coalesced = self._get_basic_sparse_coo()
result = already_coalesced.coalesce()
self.assertTrue(result.is_coalesced())
self.assertEqual(result._indices().cpu(), already_coalesced._indices().cpu())
self.assertEqual(result._values().cpu(), already_coalesced._values().cpu())
def test_coalesce_empty_sparse_tensor(self):
empty_indices = torch.zeros((2, 0), dtype=torch.int64, device="mps")
empty_values = torch.tensor([], dtype=torch.float32, device="mps")
empty_sparse = torch.sparse_coo_tensor(empty_indices, empty_values, (3, 3), device="mps")
empty_coalesced = empty_sparse.coalesce()
self.assertTrue(empty_coalesced.is_coalesced())
self.assertEqual(empty_coalesced._nnz(), 0)
def test_coalesce_large_tensor(self):
size = (1000000, 1000000)
num_elements = 1000
# 800 unique random positions
unique_indices = torch.randint(0, size[0], (2, 800), dtype=torch.int64)
# 200 duplicates by repeating some of the first 200 indices
duplicate_indices = unique_indices[:, :200]
indices = torch.cat([unique_indices, duplicate_indices], dim=1)
# shuffle indices to mix duplicates with unique entries
perm = torch.randperm(indices.size(1))
indices = indices[:, perm]
values = torch.randn(num_elements, dtype=torch.float32)
indices_mps = indices.to("mps")
values_mps = values.to("mps")
sparse_mps = torch.sparse_coo_tensor(indices_mps, values_mps, size, device="mps")
sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu")
self.assertFalse(sparse_mps.is_coalesced())
coalesced_mps = sparse_mps.coalesce()
coalesced_cpu = sparse_cpu.coalesce()
self.assertTrue(coalesced_mps.is_coalesced())
self.assertTrue(coalesced_cpu.is_coalesced())
self.assertEqual(coalesced_mps._nnz(), coalesced_cpu._nnz())
self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the


@@ -2849,14 +2849,13 @@ def main() -> None:
     # TODO: stop generating CUDA kernels for non-CUDA builds
     ignore_keys = set()
+    MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
     if options.mps or options.update_aoti_c_shim:
-        functions_keys.add(DispatchKey.MPS)
+        functions_keys.update(MPS_KEYS)
         aoti_backends.add(DispatchKey.MPS)
     else:
-        ignore_keys.add(DispatchKey.MPS)
-        if DispatchKey.MPS in dispatch_keys:
-            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
+        ignore_keys.update(MPS_KEYS)
+        dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
     if options.xpu or options.update_aoti_c_shim:
         functions_keys.add(DispatchKey.XPU)