[MPS] [Sparse] unique_dim and sparse broadcast (#163694)

Implements unique_dim, sparse broadcast ops and adds dtypes for mps for tests where we expect to fail, otherwise they would always fail due to being run in double precision

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163694
Approved by: https://github.com/malfet
This commit is contained in:
Isalia20
2025-09-26 23:03:13 +00:00
committed by PyTorch MergeBot
parent 19f16a65b4
commit 62b0ebd8f9
4 changed files with 104 additions and 8 deletions

View File

@ -9,11 +9,22 @@
#else
#include <ATen/ops/_unique2.h>
#include <ATen/ops/_unique2_native.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/argsort.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/cumsum.h>
#include <ATen/ops/full.h>
#include <ATen/ops/masked_select.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/ones.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/unique_consecutive.h>
#include <ATen/ops/unique_consecutive_native.h>
#include <ATen/ops/unique_dim_consecutive.h>
#include <ATen/ops/unique_dim_consecutive_native.h>
#include <ATen/ops/unique_dim_native.h>
#include <ATen/ops/zeros.h>
#endif
namespace at::native {
@ -305,4 +316,85 @@ std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
return _unique_impl_mps(self, return_inverse, return_counts, false, std::nullopt);
}
static Tensor lexsort_rows_perm_mps(const Tensor& mat_2d) {
const auto rows = mat_2d.size(0), cols = mat_2d.size(1);
if (rows <= 1 || cols == 0) {
return arange(rows, mat_2d.options().dtype(kLong));
}
auto perm = arange(rows, mat_2d.options().dtype(kLong));
for (auto c = cols - 1; c >= 0; --c) {
auto keys = mat_2d.select(1, c).index_select(0, perm);
const auto idx = argsort(keys, /*dim=*/0, /*descending=*/false);
perm = perm.index_select(0, idx);
}
return perm;
}
static std::tuple<Tensor, Tensor, Tensor> unique_dim_sorted_mps_impl(const Tensor& self,
int64_t dim,
bool return_inverse,
bool return_counts) {
dim = maybe_wrap_dim(dim, self.dim());
auto sizes = self.sizes().vec();
auto num_zero_dims = std::count(sizes.begin(), sizes.end(), (int64_t)0);
if (self.size(dim) == 0) {
auto output = at::empty(sizes, self.options());
auto inverse_indices = at::empty({0}, self.options().dtype(kLong));
auto counts = at::empty({0}, self.options().dtype(kLong));
return {output, inverse_indices, counts};
}
auto transposed = self.moveaxis(dim, 0);
auto orig_sizes = transposed.sizes().vec();
auto rows = transposed.size(0);
auto input_flat = transposed.contiguous().view({rows, -1});
auto perm = lexsort_rows_perm_mps(input_flat);
auto input_sorted = input_flat.index_select(0, perm);
Tensor is_unique = at::zeros({rows}, self.options().dtype(kBool));
if (rows > 0) {
is_unique.narrow(0, 0, 1).fill_(true);
}
if (rows > 1) {
auto a = input_sorted.narrow(0, 1, rows - 1);
auto b = input_sorted.narrow(0, 0, rows - 1);
auto row_changed = a.ne(b).any(1);
is_unique.narrow(0, 1, rows - 1).copy_(row_changed);
}
auto unique_pos = nonzero(is_unique).squeeze(1);
auto group_id = cumsum(is_unique.to(kLong), 0).sub(1);
auto unique_rows_2d = input_sorted.index_select(0, unique_pos);
Tensor inverse_indices = empty({0}, self.options().dtype(kLong));
if (return_inverse) {
inverse_indices = empty({rows}, self.options().dtype(kLong));
inverse_indices.index_copy_(0, perm, group_id);
}
Tensor counts = empty({0}, self.options().dtype(kLong));
if (return_counts) {
const auto num_unique = unique_pos.size(0);
counts = zeros({num_unique}, self.options().dtype(kLong));
counts.scatter_add_(0, group_id, ones_like(group_id, group_id.options().dtype(kLong)));
}
orig_sizes[0] = unique_rows_2d.size(0);
auto output = unique_rows_2d.view(orig_sizes).moveaxis(0, dim);
return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts));
}
std::tuple<Tensor, Tensor, Tensor> unique_dim_mps(const Tensor& self,
int64_t dim,
const bool /*sorted*/,
const bool return_inverse,
const bool return_counts) {
return unique_dim_sorted_mps_impl(self, dim, return_inverse, return_counts);
}
} // namespace at::native

View File

@ -1409,7 +1409,7 @@
- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
variants: function
dispatch:
SparseCPU, SparseCUDA: sparse_broadcast_to
SparseCPU, SparseCUDA, SparseMPS: sparse_broadcast_to
- func: cat(Tensor[] tensors, int dim=0) -> Tensor
structured_delegate: cat.out
@ -6450,6 +6450,7 @@
dispatch:
CPU: unique_dim_cpu
CUDA: unique_dim_cuda
MPS: unique_dim_mps
tags: dynamic_output_shape
autogen: unique_dim.out

View File

@ -547,11 +547,12 @@ class TestSparse(TestSparseBase):
@coalescedonoff
@dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble)
@expectedFailureMPS # unique_dim not implemented for MPS device
@dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32, torch.int, torch.cfloat)
def test_to_sparse(self, device, dtype, coalesced):
shape = [5, 2, 10, 4]
max_nnz = 1
for value_type in [torch.double, torch.cdouble]:
dtypes = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64]
for value_type in dtypes:
for dim, dim_sz in enumerate(shape, 1):
max_nnz *= dim_sz
rnnz = torch.randint(2, max_nnz, (1,)).item()
@ -1764,8 +1765,8 @@ class TestSparse(TestSparseBase):
test_shape(1000, 100, 0, 20)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
def test_spadd(self, device, dtype, coalesced):
def _test_spadd_shape(nnz, shape_i, shape_v=None):
@ -3856,8 +3857,8 @@ class TestSparse(TestSparseBase):
self.assertRaises(TypeError, assign_to)
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_full_broadcast_to(self, device, dtype):
def can_broadcast(s0, s1):
s0 = tuple(reversed(s0))
@ -3887,8 +3888,8 @@ class TestSparse(TestSparseBase):
torch._sparse_broadcast_to(s, s1)
@coalescedonoff
@expectedFailureMPS
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float32, torch.complex64)
def test_sparse_broadcast_to(self, device, dtype, coalesced):
def test(sparse_dims, nnz, with_size, new_size):
x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]

View File

@ -383,8 +383,6 @@ if torch.backends.mps.is_available():
"symeig": None,
"take": None,
"to": None,
"to_sparse": None,
"unique": None,
"vdot": None,
"segment_reduce_": None,
"_upsample_bilinear2d_aa": [torch.uint8], # uint8 is for CPU only
@ -758,6 +756,10 @@ if torch.backends.mps.is_available():
"eye": [torch.float16, torch.float32],
# topk fails with duplicate indices
"topk": [torch.float16],
# Could not run 'aten::uniform_' with arguments from the 'SparseCPU' backend
"to_sparse": None,
# Exception: the derivative for '_unique2' is not implemented.
"unique": None,
}
SKIPLIST_GRAD = {