[MPS] coalesce for sparse tensors (#159729)

MPS coalesce function for sparse tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159729
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Committed by: PyTorch MergeBot
Parent: 556e2a73f4
Commit: 7f4cb4a3e0
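
For context, the new kernel is reached through the regular sparse COO API. The sketch below mirrors the test added in this commit (values are illustrative) and assumes an MPS-enabled build:

import torch

# Duplicate coordinate (0, 0) appears twice; coalesce() should sum its values.
indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps")
values = torch.tensor([1., 2., 3., 4.], dtype=torch.float32, device="mps")
sparse_mps = torch.sparse_coo_tensor(indices, values, (2, 3), device="mps")

coalesced = sparse_mps.coalesce()   # dispatches to _coalesce_sparse_mps via the SparseMPS key
assert coalesced.is_coalesced()
assert coalesced._nnz() == 2        # the two (0, 0) entries were merged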
@@ -119,6 +119,8 @@ file(GLOB_RECURSE native_mps_cpp "native/mps/*.cpp")
 file(GLOB_RECURSE native_mps_mm "native/mps/*.mm")
 file(GLOB_RECURSE native_mps_metal "native/mps/*.metal")
 file(GLOB_RECURSE native_mps_h "native/mps/*.h")
+file(GLOB_RECURSE native_sparse_mps_mm "native/sparse/mps/*.mm")
+file(GLOB_RECURSE native_mps_sparse_metal "native/sparse/mps/*.metal")
 
 file(GLOB native_sparse_cpp "native/sparse/*.cpp")
 file(GLOB native_quantized_cpp
@@ -699,10 +701,10 @@ endif()
 if(USE_MPS)
   include(../../../cmake/Metal.cmake)
 
-  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h})
+  set(ATen_MPS_SRCS ${ATen_MPS_SRCS} ${mps_cpp} ${mps_mm} ${mps_h} ${native_mps_cpp} ${native_mps_mm} ${native_mps_h} ${native_sparse_mps_mm})
 
   if(CAN_COMPILE_METAL)
-    foreach(SHADER ${native_mps_metal})
+    foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
       cmake_path(GET SHADER STEM TGT_STEM)
       string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
       list(APPEND AIR_BASIC ${TGT_BASIC})
@@ -717,7 +719,7 @@ if(USE_MPS)
     add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
   else()
     file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
-    foreach(SHADER ${native_mps_metal})
+    foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal})
       cmake_path(GET SHADER STEM TGT_STEM)
       string(CONCAT SHADER_HDR_NAME "${CMAKE_CURRENT_BINARY_DIR}" /native/mps/ ${TGT_STEM} "_metallib.h")
       metal_to_metallib_h(${SHADER} ${SHADER_HDR_NAME})
@@ -7423,6 +7423,7 @@
   dispatch:
     SparseCPU: _coalesce_sparse_cpu
     SparseCUDA: _coalesce_sparse_cuda
+    SparseMPS: _coalesce_sparse_mps
   autogen: _coalesce.out
 
 - func: is_coalesced(Tensor self) -> bool
aten/src/ATen/native/sparse/mps/SparseMPSTensor.mm (new file, 220 lines)

#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#endif

namespace at::native {

using namespace mps;
using namespace at::sparse;

#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Sparse_metallib.h>
#endif

static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {
  TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
  TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
              "flatten_indices: indices.size(0) must equal size.size()");

  int64_t sparse_dim = indices.size(0);
  int64_t nnz = indices.size(1);

  if (nnz == 0) {
    return at::empty({0}, indices.options().dtype(kLong));
  }

  std::vector<int64_t> strides(sparse_dim);
  strides[sparse_dim - 1] = 1;
  for (int64_t i = sparse_dim - 2; i >= 0; i--) {
    strides[i] = strides[i + 1] * size[i + 1];
  }

  Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return flat_indices;
}

static Tensor compute_output_positions(const Tensor& is_unique) {
  int64_t nnz = is_unique.size(0);
  if (nnz == 0) {
    return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
  }

  Tensor positions = at::empty({nnz}, TensorOptions().device(kMPS).dtype(kInt));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("compute_output_positions_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, is_unique, positions);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return positions;
}

static Tensor compute_output_positions_parallel(const Tensor& is_unique) {
  int64_t nnz = is_unique.size(0);
  if (nnz == 0) {
    return at::empty({0}, TensorOptions().device(kMPS).dtype(kInt));
  }

  // for small arrays, use simple kernel
  // speed of the naive kernel drops off after 4096 nnz elements
  if (nnz <= 4096) {
    return compute_output_positions(is_unique);
  }
  auto stream = getCurrentMPSStream();
  Tensor positions = is_unique.to(kInt);
  // Kogge-Stone parallel prefix sum
  Tensor positions_cloned = positions.clone();

  for (int64_t stride = 1; stride < nnz; stride *= 2) {
    dispatch_sync_with_rethrow(stream->queue(), ^() {
      @autoreleasepool {
        auto pipeline = lib.getPipelineStateForFunc("kogge_stone_step");
        auto encoder = stream->commandEncoder();
        [encoder setComputePipelineState:pipeline];

        mtl_setArgs(encoder, positions, positions_cloned, stride);
        mtl_dispatch1DJob(encoder, pipeline, nnz);
      }
    });
    std::swap(positions, positions_cloned);
  }

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("shift_right_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, positions, positions_cloned);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  return positions_cloned;
}

static std::pair<Tensor, int32_t> mark_unique_and_count(const Tensor& flat_indices) {
  int64_t nnz = flat_indices.size(0);
  if (nnz == 0) {
    return {at::empty({0}, flat_indices.options().dtype(kBool)), 0};
  }

  Tensor is_unique = at::empty({nnz}, flat_indices.options().dtype(kBool));
  Tensor count_result = at::zeros({1}, flat_indices.options().dtype(kInt));

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("mark_unique_positions_and_count_kernel");
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      mtl_setArgs(encoder, flat_indices, is_unique, count_result);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  int32_t num_unique = count_result.item<int32_t>();

  return {is_unique, num_unique};
}

SparseTensor _coalesce_sparse_mps(const SparseTensor& self) {
  int64_t nnz = self._nnz();
  TORCH_INTERNAL_ASSERT(!self.is_coalesced());
  if (nnz < 2) {
    SparseTensor dst = self.clone();
    dst._coalesced_(true);
    return dst;
  }

  Tensor indices = self._indices();
  Tensor values = self._values();

  Tensor flat_indices = flatten_indices(indices, self.sizes());
  Tensor sorted_order = flat_indices.argsort();
  Tensor flat_indices_sorted = flat_indices.index({sorted_order});
  values = values.index({sorted_order});
  indices = indices.index_select(1, sorted_order);

  auto unique_info = mark_unique_and_count(flat_indices_sorted);
  Tensor is_unique = unique_info.first;
  int32_t newNnz = unique_info.second;

  Tensor output_positions = compute_output_positions_parallel(is_unique);

  Tensor out_indices = at::empty({indices.size(0), newNnz}, indices.options());
  auto outValuesSize = values.sizes().vec();
  outValuesSize[0] = newNnz;
  Tensor out_values = at::zeros(outValuesSize, values.options());

  Tensor is_unique_local = is_unique;
  int64_t sparse_dim = indices.size(0);

  auto stream = getCurrentMPSStream();
  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      auto pipeline = lib.getPipelineStateForFunc("coalesce_with_positions_kernel_" + scalarToMetalTypeString(values));
      auto encoder = stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];

      const uint32_t numThreads = static_cast<uint32_t>(nnz);
      const uint32_t valueSize = static_cast<uint32_t>(values.numel() / nnz);
      mtl_setArgs(encoder,
                  flat_indices_sorted,
                  indices,
                  values,
                  is_unique_local,
                  output_positions,
                  out_indices,
                  out_values,
                  numThreads,
                  valueSize,
                  sparse_dim,
                  newNnz);
      mtl_dispatch1DJob(encoder, pipeline, nnz);
    }
  });

  SparseTensor result = _sparse_coo_tensor_unsafe_symint(out_indices, out_values, self.sym_sizes())._coalesced_(true);
  return result;
}

} // namespace at::native
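
As a reading aid, not part of the commit, the host-side flow above (flatten the sparse indices, sort, mark first occurrences, scan for output slots, accumulate duplicates) can be restated as a short CPU reference in plain PyTorch. The helper name coalesce_reference is hypothetical, and the device kernels sum each run inside coalesce_with_positions_kernel rather than via index_add_:

import torch

def coalesce_reference(indices: torch.Tensor, values: torch.Tensor, size):
    """Illustrative CPU analogue of _coalesce_sparse_mps: flatten, sort, mark unique, accumulate."""
    sparse_dim, nnz = indices.shape
    # Row-major strides over the sparse dims, as in flatten_indices().
    strides = [1] * sparse_dim
    for d in range(sparse_dim - 2, -1, -1):
        strides[d] = strides[d + 1] * size[d + 1]
    flat = (indices * torch.tensor(strides, dtype=torch.int64).unsqueeze(1)).sum(dim=0)

    order = flat.argsort()
    flat, indices, values = flat[order], indices[:, order], values[order]

    # mark_unique_and_count: flag the first occurrence of each flattened index.
    is_unique = torch.ones(nnz, dtype=torch.bool)
    is_unique[1:] = flat[1:] != flat[:-1]

    # The Metal path builds an exclusive scan ("output_positions") that is only read at
    # unique entries; for a CPU reference, an inclusive scan minus one gives every
    # element the slot of the unique entry that starts its run.
    slot = torch.cumsum(is_unique.to(torch.int64), dim=0) - 1

    new_nnz = int(is_unique.sum())
    out_indices = indices[:, is_unique]
    out_values = torch.zeros((new_nnz,) + tuple(values.shape[1:]), dtype=values.dtype)
    out_values.index_add_(0, slot, values)  # duplicate rows sum into their slot
    return out_indices, out_values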
aten/src/ATen/native/sparse/mps/kernels/Sparse.metal (new file, 123 lines)

#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;

kernel void flatten_indices_kernel(
    device const int64_t* indices [[buffer(0)]],
    device const int64_t* strides [[buffer(1)]],
    device int64_t* flat_indices [[buffer(2)]],
    constant uint& sparse_dim [[buffer(3)]],
    constant uint& nnz [[buffer(4)]],
    uint gid [[thread_position_in_grid]]) {
  int64_t flat_idx = 0;
  for (uint d = 0; d < sparse_dim; d++) {
    flat_idx += indices[d * nnz + gid] * strides[d];
  }
  flat_indices[gid] = flat_idx;
}

kernel void compute_output_positions_kernel(
    device const bool* is_unique [[buffer(0)]],
    device int* positions [[buffer(1)]],
    uint gid [[thread_position_in_grid]]) {
  int pos = 0;
  for (uint i = 0; i < gid; i++) {
    if (is_unique[i])
      pos++;
  }
  positions[gid] = pos;
}

kernel void mark_unique_positions_and_count_kernel(
    device const int64_t* flat_indices [[buffer(0)]],
    device bool* is_unique [[buffer(1)]],
    device atomic_int* count [[buffer(2)]],
    uint tid [[thread_position_in_grid]]) {
  bool unique = (tid == 0) || (flat_indices[tid] != flat_indices[tid - 1]);
  is_unique[tid] = unique;

  if (unique) {
    atomic_fetch_add_explicit(count, 1, memory_order_relaxed);
  }
}

// Kogge-Stone parallel prefix sum step
kernel void kogge_stone_step(
    device const int* input [[buffer(0)]],
    device int* output [[buffer(1)]],
    constant uint& stride [[buffer(2)]],
    uint gid [[thread_position_in_grid]]) {
  int val = input[gid];
  if (gid >= stride) {
    val += input[gid - stride];
  }
  output[gid] = val;
}

// Shift right for exclusive scan
kernel void shift_right_kernel(
    device const int* input [[buffer(0)]],
    device int* output [[buffer(1)]],
    uint gid [[thread_position_in_grid]]) {
  output[gid] = (gid == 0) ? 0 : input[gid - 1];
}

template <typename T>
kernel void coalesce_with_positions_kernel(
    device const int64_t* flat_indices [[buffer(0)]],
    device const int64_t* indices [[buffer(1)]],
    device const T* in_values [[buffer(2)]],
    device const bool* is_unique [[buffer(3)]],
    device const int* output_positions [[buffer(4)]],
    device int64_t* out_indices [[buffer(5)]],
    device T* out_values [[buffer(6)]],
    constant uint& nnz [[buffer(7)]],
    constant uint& value_size [[buffer(8)]],
    constant uint& sparse_dim [[buffer(9)]],
    constant uint& total_unique [[buffer(10)]],
    uint gid [[thread_position_in_grid]]) {
  if (!is_unique[gid])
    return;

  int out_pos = output_positions[gid];

  for (uint d = 0; d < sparse_dim; d++) {
    out_indices[d * total_unique + out_pos] = indices[d * nnz + gid];
  }

  int64_t current_index = flat_indices[gid];
  uint end = gid + 1;
  while (end < nnz && flat_indices[end] == current_index) {
    end++;
  }

  for (uint elem = 0; elem < value_size; elem++) {
    T sum = 0;
    for (uint j = gid; j < end; j++) {
      sum += in_values[j * value_size + elem];
    }
    out_values[out_pos * value_size + elem] = sum;
  }
}

#define INSTANTIATE_COALESCE_WITH_POSITIONS(DTYPE)                      \
  template [[host_name("coalesce_with_positions_kernel_" #DTYPE)]]     \
  [[kernel]] void coalesce_with_positions_kernel<DTYPE>(               \
      device const int64_t* flat_indices [[buffer(0)]],                \
      device const int64_t* indices [[buffer(1)]],                     \
      device const DTYPE* in_values [[buffer(2)]],                     \
      device const bool* is_unique [[buffer(3)]],                      \
      device const int* output_positions [[buffer(4)]],                \
      device int64_t* out_indices [[buffer(5)]],                       \
      device DTYPE* out_values [[buffer(6)]],                          \
      constant uint& nnz [[buffer(7)]],                                \
      constant uint& value_size [[buffer(8)]],                         \
      constant uint& sparse_dim [[buffer(9)]],                         \
      constant uint& total_unique [[buffer(10)]],                      \
      uint gid [[thread_position_in_grid]]);

INSTANTIATE_COALESCE_WITH_POSITIONS(float);
INSTANTIATE_COALESCE_WITH_POSITIONS(half);
INSTANTIATE_COALESCE_WITH_POSITIONS(bfloat);
INSTANTIATE_COALESCE_WITH_POSITIONS(bool);
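
The kogge_stone_step and shift_right_kernel pair computes an exclusive prefix sum over the is_unique flags: log2(nnz) doubling-stride passes build an inclusive scan, and the final shift converts it to an exclusive one. A minimal host-side sketch of the same scheme (illustrative Python, not part of the commit):

def exclusive_scan_kogge_stone(flags):
    """Exclusive prefix sum built from Kogge-Stone doubling strides plus a final right shift."""
    n = len(flags)
    positions = [int(f) for f in flags]  # working buffer, as in is_unique.to(kInt)
    stride = 1
    while stride < n:
        # kogge_stone_step: out[i] = in[i] + in[i - stride] for i >= stride
        positions = [positions[i] + (positions[i - stride] if i >= stride else 0)
                     for i in range(n)]
        stride *= 2
    # shift_right_kernel: exclusive scan is the inclusive scan shifted right by one
    return [0] + positions[:-1]

# is_unique = [True, False, True, True]  ->  output positions [0, 1, 1, 2]
assert exclusive_scan_kogge_stone([True, False, True, True]) == [0, 1, 1, 2]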
@@ -237,8 +237,6 @@ inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::CPU;
     case Backend::CUDA:
     case Backend::SparseCUDA:
-    case Backend::SparseMPS:
-    case Backend::SparseCsrMPS:
     case Backend::QuantizedCUDA:
     case Backend::SparseCsrCUDA:
       return DeviceType::CUDA;
@@ -276,6 +274,8 @@ inline DeviceType backendToDeviceType(Backend b) {
     case Backend::Meta:
       return DeviceType::Meta;
     case Backend::MPS:
+    case Backend::SparseMPS:
+    case Backend::SparseCsrMPS:
       return DeviceType::MPS;
     case Backend::HPU:
       return DeviceType::HPU;
@@ -33,7 +33,6 @@ inline Layout layout_from_backend(Backend backend) {
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
     case Backend::SparseMPS:
-    case Backend::SparseCsrMPS:
     case Backend::SparseHIP:
     case Backend::SparseVE:
     case Backend::SparseXPU:
@@ -43,6 +42,7 @@ inline Layout layout_from_backend(Backend backend) {
       return Layout::Mkldnn;
     case Backend::SparseCsrCPU:
     case Backend::SparseCsrCUDA:
+    case Backend::SparseCsrMPS:
     case Backend::SparseCsrHIP:
     case Backend::SparseCsrVE:
     case Backend::SparseCsrXPU:
@@ -2090,6 +2090,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     constexpr auto sparse_backends = DispatchKeySet(
         {BackendComponent::CPUBit,
          BackendComponent::CUDABit,
+         BackendComponent::MPSBit,
          BackendComponent::HIPBit,
          BackendComponent::XPUBit});
     constexpr auto sparse_k = DispatchKeySet(DispatchKey::Sparse);
@@ -12696,6 +12696,65 @@ class TestSparseMPS(TestCaseMPS):
         sparse_cpu = sparse_cpu.sparse_resize_(torch.Size([4, 5]), sparse_dim=2, dense_dim=0)
         self.assertEqual(sparse, sparse_cpu)
 
+    def test_coalesce(self):
+        indices = torch.tensor([[0, 0, 1, 1], [0, 0, 2, 2]], dtype=torch.int64, device="mps")
+        values = torch.tensor([1., 2., 3., 4.], dtype=torch.float32, device="mps")
+        size = (2, 3)
+        indices_cpu = indices.cpu()
+        values_cpu = values.cpu()
+        sparse_mps = torch.sparse_coo_tensor(indices, values, size, device="mps")
+        sparse_cpu = torch.sparse_coo_tensor(indices_cpu, values_cpu, size, device="cpu")
+        coalesced_mps = sparse_mps.coalesce()
+        coalesced_cpu = sparse_cpu.coalesce()
+
+        self.assertTrue(coalesced_mps.is_coalesced())
+        self.assertTrue(coalesced_cpu.is_coalesced())
+        self.assertEqual(coalesced_mps._nnz(), 2)
+        self.assertEqual(coalesced_mps.cpu(), coalesced_cpu)
+
+    def test_already_coalesced_tensor(self):
+        already_coalesced = self._get_basic_sparse_coo()
+        result = already_coalesced.coalesce()
+        self.assertTrue(result.is_coalesced())
+        self.assertEqual(result._indices().cpu(), already_coalesced._indices().cpu())
+        self.assertEqual(result._values().cpu(), already_coalesced._values().cpu())
+
+    def test_coalesce_empty_sparse_tensor(self):
+        empty_indices = torch.zeros((2, 0), dtype=torch.int64, device="mps")
+        empty_values = torch.tensor([], dtype=torch.float32, device="mps")
+        empty_sparse = torch.sparse_coo_tensor(empty_indices, empty_values, (3, 3), device="mps")
+        empty_coalesced = empty_sparse.coalesce()
+        self.assertTrue(empty_coalesced.is_coalesced())
+        self.assertEqual(empty_coalesced._nnz(), 0)
+
+    def test_coalesce_large_tensor(self):
+        size = (1000000, 1000000)
+        num_elements = 1000
+
+        # 800 unique random positions
+        unique_indices = torch.randint(0, size[0], (2, 800), dtype=torch.int64)
+        # 200 duplicates by repeating some of the first 200 indices
+        duplicate_indices = unique_indices[:, :200]
+        indices = torch.cat([unique_indices, duplicate_indices], dim=1)
+        # shuffle indices to mix duplicates with unique entries
+        perm = torch.randperm(indices.size(1))
+        indices = indices[:, perm]
+
+        values = torch.randn(num_elements, dtype=torch.float32)
+        indices_mps = indices.to("mps")
+        values_mps = values.to("mps")
+        sparse_mps = torch.sparse_coo_tensor(indices_mps, values_mps, size, device="mps")
+        sparse_cpu = torch.sparse_coo_tensor(indices, values, size, device="cpu")
+
+        self.assertFalse(sparse_mps.is_coalesced())
+        coalesced_mps = sparse_mps.coalesce()
+        coalesced_cpu = sparse_cpu.coalesce()
+        self.assertTrue(coalesced_mps.is_coalesced())
+        self.assertTrue(coalesced_cpu.is_coalesced())
+        self.assertEqual(coalesced_mps._nnz(), coalesced_cpu._nnz())
+        self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
+        self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())
+
+
 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
@@ -2849,14 +2849,13 @@ def main() -> None:
     # TODO: stop generating CUDA kernels for non-CUDA builds
     ignore_keys = set()
 
+    MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
     if options.mps or options.update_aoti_c_shim:
-        functions_keys.add(DispatchKey.MPS)
+        functions_keys.update(MPS_KEYS)
         aoti_backends.add(DispatchKey.MPS)
     else:
-        ignore_keys.add(DispatchKey.MPS)
-
-        if DispatchKey.MPS in dispatch_keys:
-            del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
+        ignore_keys.update(MPS_KEYS)
+        dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
 
     if options.xpu or options.update_aoti_c_shim:
         functions_keys.add(DispatchKey.XPU)