[MPS] [Sparse] unique_dim and sparse broadcast (#163694)

Implements unique_dim and the sparse broadcast op for MPS, and adds MPS-specific dtypes to the tests that previously expected failure; without the override those tests would always fail on MPS because they run in double precision, which the backend does not support.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163694
Approved by: https://github.com/malfet

Committed by: PyTorch MergeBot
Parent: 19f16a65b4
Commit: 62b0ebd8f9
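
With unique_dim wired up, torch.unique over a dimension now works on MPS tensors. A minimal usage sketch of the new behavior, assuming a machine where the MPS backend is available:

import torch

if torch.backends.mps.is_available():
    x = torch.tensor([[1, 2], [3, 4], [1, 2]], device="mps")
    # Row-wise unique dispatches to the new unique_dim_mps kernel.
    vals, inverse, counts = torch.unique(x, dim=0, return_inverse=True, return_counts=True)
    # vals:    [[1, 2], [3, 4]]   (unique rows, sorted lexicographically)
    # inverse: [0, 1, 0]          (each input row's index into vals)
    # counts:  [2, 1]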
aten/src/ATen/native/mps/operations/Unique.mm

@@ -9,11 +9,22 @@
 #else
 #include <ATen/ops/_unique2.h>
 #include <ATen/ops/_unique2_native.h>
+#include <ATen/ops/arange.h>
+#include <ATen/ops/argsort.h>
+#include <ATen/ops/cat.h>
+#include <ATen/ops/cumsum.h>
+#include <ATen/ops/full.h>
+#include <ATen/ops/masked_select.h>
+#include <ATen/ops/nonzero.h>
+#include <ATen/ops/ones.h>
+#include <ATen/ops/ones_like.h>
+#include <ATen/ops/slice.h>
 #include <ATen/ops/unique_consecutive.h>
 #include <ATen/ops/unique_consecutive_native.h>
 #include <ATen/ops/unique_dim_consecutive.h>
 #include <ATen/ops/unique_dim_consecutive_native.h>
 #include <ATen/ops/unique_dim_native.h>
+#include <ATen/ops/zeros.h>
 #endif
 
 namespace at::native {
@@ -305,4 +316,85 @@ std::tuple<Tensor, Tensor, Tensor> _unique2_mps(const Tensor& self,
   return _unique_impl_mps(self, return_inverse, return_counts, false, std::nullopt);
 }
 
+static Tensor lexsort_rows_perm_mps(const Tensor& mat_2d) {
+  const auto rows = mat_2d.size(0), cols = mat_2d.size(1);
+  if (rows <= 1 || cols == 0) {
+    return arange(rows, mat_2d.options().dtype(kLong));
+  }
+
+  auto perm = arange(rows, mat_2d.options().dtype(kLong));
+  for (auto c = cols - 1; c >= 0; --c) {
+    auto keys = mat_2d.select(1, c).index_select(0, perm);
+    const auto idx = argsort(keys, /*dim=*/0, /*descending=*/false);
+    perm = perm.index_select(0, idx);
+  }
+  return perm;
+}
+
+static std::tuple<Tensor, Tensor, Tensor> unique_dim_sorted_mps_impl(const Tensor& self,
+                                                                     int64_t dim,
+                                                                     bool return_inverse,
+                                                                     bool return_counts) {
+  dim = maybe_wrap_dim(dim, self.dim());
+
+  auto sizes = self.sizes().vec();
+  auto num_zero_dims = std::count(sizes.begin(), sizes.end(), (int64_t)0);
+  if (self.size(dim) == 0) {
+    auto output = at::empty(sizes, self.options());
+    auto inverse_indices = at::empty({0}, self.options().dtype(kLong));
+    auto counts = at::empty({0}, self.options().dtype(kLong));
+    return {output, inverse_indices, counts};
+  }
+
+  auto transposed = self.moveaxis(dim, 0);
+  auto orig_sizes = transposed.sizes().vec();
+  auto rows = transposed.size(0);
+  auto input_flat = transposed.contiguous().view({rows, -1});
+
+  auto perm = lexsort_rows_perm_mps(input_flat);
+  auto input_sorted = input_flat.index_select(0, perm);
+
+  Tensor is_unique = at::zeros({rows}, self.options().dtype(kBool));
+  if (rows > 0) {
+    is_unique.narrow(0, 0, 1).fill_(true);
+  }
+  if (rows > 1) {
+    auto a = input_sorted.narrow(0, 1, rows - 1);
+    auto b = input_sorted.narrow(0, 0, rows - 1);
+    auto row_changed = a.ne(b).any(1);
+    is_unique.narrow(0, 1, rows - 1).copy_(row_changed);
+  }
+
+  auto unique_pos = nonzero(is_unique).squeeze(1);
+  auto group_id = cumsum(is_unique.to(kLong), 0).sub(1);
+
+  auto unique_rows_2d = input_sorted.index_select(0, unique_pos);
+
+  Tensor inverse_indices = empty({0}, self.options().dtype(kLong));
+  if (return_inverse) {
+    inverse_indices = empty({rows}, self.options().dtype(kLong));
+    inverse_indices.index_copy_(0, perm, group_id);
+  }
+
+  Tensor counts = empty({0}, self.options().dtype(kLong));
+  if (return_counts) {
+    const auto num_unique = unique_pos.size(0);
+    counts = zeros({num_unique}, self.options().dtype(kLong));
+    counts.scatter_add_(0, group_id, ones_like(group_id, group_id.options().dtype(kLong)));
+  }
+
+  orig_sizes[0] = unique_rows_2d.size(0);
+  auto output = unique_rows_2d.view(orig_sizes).moveaxis(0, dim);
+
+  return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts));
+}
+
+std::tuple<Tensor, Tensor, Tensor> unique_dim_mps(const Tensor& self,
+                                                  int64_t dim,
+                                                  const bool /*sorted*/,
+                                                  const bool return_inverse,
+                                                  const bool return_counts) {
+  return unique_dim_sorted_mps_impl(self, dim, return_inverse, return_counts);
+}
+
 } // namespace at::native
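
The implementation above builds unique_dim from primitives that already run on MPS rather than a dedicated Metal kernel: the target dimension is moved to the front and the tensor flattened to 2-D; rows are ordered lexicographically by one argsort pass per column, from last column to first (an LSD-style sort, which relies on each pass being stable); a pairwise ne().any() marks the first row of every run of equal rows; a cumulative sum of those flags gives a group id per sorted row, which, scattered back through the sort permutation, yields the inverse indices; and scatter_add_ accumulates the counts. A rough Python rendering of the same pipeline (an illustrative sketch with a hypothetical helper name, not the shipped code):

import torch

def unique_rows_sketch(mat):
    # Dedup the rows of a 2-D tensor; returns (unique, inverse, counts).
    rows, cols = mat.shape
    # Lexicographic row order from per-column argsorts, last column first;
    # stable=True so earlier columns take precedence over later ones.
    perm = torch.arange(rows, device=mat.device)
    for c in range(cols - 1, -1, -1):
        keys = mat[:, c].index_select(0, perm)
        perm = perm.index_select(0, torch.argsort(keys, stable=True))
    sorted_rows = mat.index_select(0, perm)
    # Flag the first row of every run of identical sorted rows.
    is_unique = torch.ones(rows, dtype=torch.bool, device=mat.device)
    is_unique[1:] = sorted_rows[1:].ne(sorted_rows[:-1]).any(dim=1)
    group_id = is_unique.long().cumsum(0) - 1    # run index of each sorted row
    unique = sorted_rows[is_unique]              # one representative per run
    inverse = torch.empty(rows, dtype=torch.long, device=mat.device)
    inverse.index_copy_(0, perm, group_id)       # scatter back through the sort
    counts = torch.zeros(unique.size(0), dtype=torch.long, device=mat.device)
    counts.scatter_add_(0, group_id, torch.ones_like(group_id))
    return unique, inverse, counts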
aten/src/ATen/native/native_functions.yaml

@@ -1409,7 +1409,7 @@
 - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
   variants: function
   dispatch:
-    SparseCPU, SparseCUDA: sparse_broadcast_to
+    SparseCPU, SparseCUDA, SparseMPS: sparse_broadcast_to
 
 - func: cat(Tensor[] tensors, int dim=0) -> Tensor
   structured_delegate: cat.out
@@ -6450,6 +6450,7 @@
   dispatch:
     CPU: unique_dim_cpu
     CUDA: unique_dim_cuda
+    MPS: unique_dim_mps
   tags: dynamic_output_shape
   autogen: unique_dim.out
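
The first yaml hunk is the entire sparse-broadcast part of the patch: the existing sparse_broadcast_to kernel is simply made available under the SparseMPS dispatch key. A small usage sketch, assuming a machine where MPS is available:

import torch

if torch.backends.mps.is_available():
    idx = torch.tensor([[0, 1], [0, 1]], device="mps")
    val = torch.tensor([1.0, 2.0], device="mps")
    s = torch.sparse_coo_tensor(idx, val, (2, 2))
    # (2, 2) broadcasts against (3, 2, 2); this now routes through the
    # shared sparse_broadcast_to kernel for sparse MPS tensors as well.
    out = torch._sparse_broadcast_to(s, (3, 2, 2))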
test/test_sparse.py

@@ -547,11 +547,12 @@ class TestSparse(TestSparseBase):
 
     @coalescedonoff
     @dtypes(torch.float16, torch.bfloat16, torch.float64, torch.int, torch.cfloat, torch.cdouble)
-    @expectedFailureMPS  # unique_dim not implemented for MPS device
+    @dtypesIfMPS(torch.float16, torch.bfloat16, torch.float32, torch.int, torch.cfloat)
     def test_to_sparse(self, device, dtype, coalesced):
         shape = [5, 2, 10, 4]
         max_nnz = 1
-        for value_type in [torch.double, torch.cdouble]:
+        dtypes = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64]
+        for value_type in dtypes:
             for dim, dim_sz in enumerate(shape, 1):
                 max_nnz *= dim_sz
                 rnnz = torch.randint(2, max_nnz, (1,)).item()

@@ -1764,8 +1765,8 @@ class TestSparse(TestSparseBase):
         test_shape(1000, 100, 0, 20)
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double)
+    @dtypesIfMPS(torch.float32)
     def test_spadd(self, device, dtype, coalesced):
 
         def _test_spadd_shape(nnz, shape_i, shape_v=None):

@@ -3856,8 +3857,8 @@ class TestSparse(TestSparseBase):
 
         self.assertRaises(TypeError, assign_to)
 
-    @expectedFailureMPS
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfMPS(torch.float32, torch.complex64)
     def test_full_broadcast_to(self, device, dtype):
         def can_broadcast(s0, s1):
             s0 = tuple(reversed(s0))

@@ -3887,8 +3888,8 @@ class TestSparse(TestSparseBase):
             torch._sparse_broadcast_to(s, s1)
 
     @coalescedonoff
-    @expectedFailureMPS
     @dtypes(torch.double, torch.cdouble)
+    @dtypesIfMPS(torch.float32, torch.complex64)
     def test_sparse_broadcast_to(self, device, dtype, coalesced):
         def test(sparse_dims, nnz, with_size, new_size):
             x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
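
All four test edits above follow the same pattern: drop expectedFailureMPS now that the ops exist, and pair the default @dtypes list with an @dtypesIfMPS override, because MPS supports neither float64 nor complex128. Schematically (a hypothetical test, not one from the patch):

    @dtypes(torch.double, torch.cdouble)           # CPU/CUDA run in 64-bit precision
    @dtypesIfMPS(torch.float32, torch.complex64)   # MPS caps out at 32-bit types
    def test_some_sparse_op(self, device, dtype):
        ...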
test/test_mps.py

@@ -383,8 +383,6 @@ if torch.backends.mps.is_available():
         "symeig": None,
         "take": None,
         "to": None,
-        "to_sparse": None,
-        "unique": None,
         "vdot": None,
         "segment_reduce_": None,
         "_upsample_bilinear2d_aa": [torch.uint8],  # uint8 is for CPU only

@@ -758,6 +756,10 @@ if torch.backends.mps.is_available():
         "eye": [torch.float16, torch.float32],
         # topk fails with duplicate indices
         "topk": [torch.float16],
+        # Could not run 'aten::uniform_' with arguments from the 'SparseCPU' backend
+        "to_sparse": None,
+        # Exception: the derivative for '_unique2' is not implemented.
+        "unique": None,
     }
 
     SKIPLIST_GRAD = {