Revert "[MPS] sparse add unary funcs + add for sparse tensors (#160839)"

This reverts commit 93c5112f46a978a029644ae599979416ead5c917.

Reverted https://github.com/pytorch/pytorch/pull/160839 on behalf of https://github.com/atalman due to test_sparse_csr.py::TestSparseCompressedCPU::test_consistency_SparseCSR_asinh_cpu_complex64 [GH job link](https://github.com/pytorch/pytorch/actions/runs/17329155095/job/49201551217) [HUD commit link](93c5112f46) ([comment](https://github.com/pytorch/pytorch/pull/160839#issuecomment-3238093296))
PyTorch MergeBot
2025-08-29 19:55:39 +00:00
parent bf6aaba0f7
commit f6368e934e
10 changed files with 85 additions and 471 deletions
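The failing check named in the revert reason is one of the sparse-compressed consistency tests: it compares a unary op applied to a SparseCSR tensor against the same op applied to its dense equivalent. A hedged sketch of that style of check (assumed shapes and values, not the actual test body):

```python
import torch

# Illustrative only: roughly the consistency pattern behind
# test_consistency_SparseCSR_asinh_cpu_complex64; shapes and values are assumptions.
dense = torch.randn(4, 4, dtype=torch.complex64)
sparse_csr = dense.to_sparse_csr()

expected = torch.asinh(dense)                 # reference result on the dense tensor
actual = torch.asinh(sparse_csr).to_dense()   # same op through the SparseCSR path

torch.testing.assert_close(actual, expected)
```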

View File

@@ -417,7 +417,6 @@ TORCH_IMPL_FUNC(sgn_out_mps)(const Tensor& self, const Tensor& output) {
Tensor& conj_physical_out_mps(const Tensor& self, Tensor& result) {
TORCH_CHECK(self.is_complex());
TORCH_CHECK(self.dtype() != at::kComplexDouble);
mps::unary_op(self, result, "conj", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
return [mpsGraph conjugateWithTensor:inputTensor name:nil];
});

View File

@@ -340,8 +340,8 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: abs
SparseCPU, SparseCUDA, SparseMPS: abs_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr
SparseCPU, SparseCUDA: abs_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs
tags: [core, pointwise]
@@ -350,16 +350,16 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: abs_
SparseCPU, SparseCUDA, SparseMPS: abs_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_
SparseCPU, SparseCUDA: abs_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_abs_
- func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA, MPS, MTIA: abs_out
SparseCPU, SparseCUDA, SparseMPS: abs_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: abs_sparse_csr_out
SparseCPU, SparseCUDA: abs_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: abs_sparse_csr_out
tags: pointwise
# Note [Adding an alias]
@@ -476,7 +476,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: _conj_physical
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr
autogen: _conj_physical.out
- func: conj_physical(Tensor self) -> Tensor
@@ -487,8 +487,8 @@
dispatch:
CPU, CUDA: conj_physical_out
MPS: conj_physical_out_mps
SparseCPU, SparseCUDA, SparseMPS: conj_physical_out_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr_out
SparseCPU, SparseCUDA: conj_physical_out_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: conj_physical_sparse_csr_out
tags: pointwise
- func: conj_physical_(Tensor(a!) self) -> Tensor(a!)
@@ -554,7 +554,7 @@
structured_delegate: add.out
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse
SparseCPU, SparseCUDA, SparseMeta: add_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr
MkldnnCPU: mkldnn_add
ZeroTensor: add_zerotensor
@@ -566,7 +566,7 @@
variants: method
structured_delegate: add.out
dispatch:
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: add_sparse_
SparseCPU, SparseCUDA, SparseMeta: add_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: add_sparse_csr_
MkldnnCPU: mkldnn_add_
NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: NestedTensor_add__Tensor
@@ -582,7 +582,6 @@
dispatch:
SparseCPU, SparseMeta: add_out_sparse_cpu
SparseCUDA: add_out_sparse_cuda
SparseMPS: add_out_sparse_mps
SparseCsrCPU, SparseCsrMeta: add_out_sparse_compressed_cpu
SparseCsrCUDA: add_out_sparse_compressed_cuda
MkldnnCPU: mkldnn_add_out
@@ -2407,7 +2406,7 @@
MPS: empty_mps
Meta: empty_meta_symint
MkldnnCPU: empty_mkldnn
SparseCPU, SparseCUDA, SparseMPS: empty_sparse
SparseCPU, SparseCUDA: empty_sparse
SparseMeta: empty_sparse_symint
SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
SparseCsrMeta: empty_sparse_compressed_symint
@@ -6386,8 +6385,8 @@
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: trunc_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr
SparseCPU, SparseCUDA: trunc_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr
tags: [core, pointwise]
- func: trunc_(Tensor(a!) self) -> Tensor(a!)
@@ -6395,8 +6394,8 @@
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_
SparseCPU, SparseCUDA: trunc_sparse_
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_
tags: pointwise
- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
@@ -6405,8 +6404,8 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA, MPS: trunc_out
SparseCPU, SparseCUDA, SparseMPS: trunc_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: trunc_sparse_csr_out
SparseCPU, SparseCUDA: trunc_sparse_out
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: trunc_sparse_csr_out
tags: pointwise
# Alias for trunc
@@ -7368,8 +7367,8 @@
- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS: sparse_to_dense
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: sparse_compressed_to_dense
SparseCPU, SparseCUDA: sparse_to_dense
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_compressed_to_dense
MkldnnCPU: mkldnn_to_dense
autogen: _to_dense.out
@@ -7395,8 +7394,8 @@
- func: dense_dim(Tensor self) -> int
variants: method
dispatch:
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: dense_dim_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: dense_dim_sparse_csr
SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr
CompositeExplicitAutograd: dense_dim_default
device_check: NoCheck
device_guard: False
@@ -7529,7 +7528,7 @@
device_check: NoCheck # Allows copy into different device
variants: function
dispatch:
SparseCPU, SparseCUDA, SparseMPS, SparseMeta: copy_sparse_
SparseCPU, SparseCUDA, SparseMeta: copy_sparse_
autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out
# By adding the AutogradNestedTensor this makes this function CompositeImplicit-like for nested tensors
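Net effect of the native_functions.yaml hunks above: the SparseMPS and SparseCsrMPS dispatch keys added by #160839 are removed again, so abs, conj_physical, add, trunc, _to_dense, dense_dim, and copy_ no longer route to sparse kernels for MPS sparse tensors. A hedged sketch of what that means from Python on an MPS build (sparse COO construction on MPS predates this PR and is assumed to still work; the exact error type and message may vary):

```python
import torch

if torch.backends.mps.is_available():
    idx = torch.tensor([[0, 1], [2, 0]], device="mps")
    vals = torch.tensor([1.0, -2.0], device="mps")
    s = torch.sparse_coo_tensor(idx, vals, (2, 3), device="mps")
    try:
        torch.abs(s)  # abs_sparse is no longer registered for the SparseMPS key
    except (NotImplementedError, RuntimeError) as err:
        print("abs on a sparse MPS tensor no longer dispatches:", err)
```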

View File

@@ -1,73 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/sparse/SparseStubs.h>
#include <ATen/native/sparse/FlattenIndicesCommon.h>
#include <ATen/ExpandUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#endif
namespace at::native {
namespace {
using namespace mps;
using namespace at::sparse;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/FlattenIndices_metallib.h>
#endif
Tensor flatten_indices_mps(const Tensor& indices, IntArrayRef size) {
TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
"flatten_indices: indices.size(0) must equal size.size()");
const int64_t sparse_dim = indices.size(0);
const int64_t nnz = indices.size(1);
if (nnz == 0) {
return at::empty({0}, indices.options().dtype(kLong));
}
// Row-major multipliers for flattening: mul[d] = prod_{j>d}(size[j])
std::vector<int64_t> row_muls(sparse_dim);
row_muls[sparse_dim - 1] = 1;
for (int64_t i = sparse_dim - 2; i >= 0; --i) {
row_muls[i] = row_muls[i + 1] * size[i + 1];
}
auto flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder,
indices,
row_muls,
flat_indices,
static_cast<uint>(sparse_dim),
indices.strides()
);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return flat_indices;
}
} // namespace
REGISTER_MPS_DISPATCH(flatten_indices_stub, &flatten_indices_mps)
} // namespace at::native
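The deleted FlattenIndices.mm above (like the inline flatten_indices restored in the next file) maps a (sparse_dim, nnz) matrix of COO indices to flat row-major offsets using multipliers mul[d] = prod_{j>d} size[j]. A minimal CPU reference of the same computation, for illustration only:

```python
import torch

def flatten_indices_ref(indices: torch.Tensor, size: list) -> torch.Tensor:
    # indices: (sparse_dim, nnz) int64 tensor; size: the tensor's size per sparse dim.
    sparse_dim, nnz = indices.shape
    # Row-major multipliers: mul[d] = prod over j > d of size[j]
    muls = [1] * sparse_dim
    for d in range(sparse_dim - 2, -1, -1):
        muls[d] = muls[d + 1] * size[d + 1]
    flat = torch.zeros(nnz, dtype=torch.int64)
    for d in range(sparse_dim):
        flat += indices[d] * muls[d]
    return flat

# Indices of a (2, 3, 4) tensor flattened to offsets into a row-major view.
idx = torch.tensor([[0, 1], [2, 0], [3, 1]])
print(flatten_indices_ref(idx, [2, 3, 4]))  # tensor([11, 13])
```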

View File

@@ -20,9 +20,46 @@ using namespace at::sparse;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Coalesce_metallib.h>
#include <ATen/native/mps/Sparse_metallib.h>
#endif
static Tensor flatten_indices(const Tensor& indices, IntArrayRef size) {
TORCH_CHECK(indices.dim() == 2, "flatten_indices: indices must be 2D");
TORCH_CHECK(static_cast<size_t>(indices.size(0)) == size.size(),
"flatten_indices: indices.size(0) must equal size.size()");
int64_t sparse_dim = indices.size(0);
int64_t nnz = indices.size(1);
if (nnz == 0) {
return at::empty({0}, indices.options().dtype(kLong));
}
std::vector<int64_t> strides(sparse_dim);
strides[sparse_dim - 1] = 1;
for (int64_t i = sparse_dim - 2; i >= 0; i--) {
strides[i] = strides[i + 1] * size[i + 1];
}
Tensor flat_indices = at::empty({nnz}, indices.options().dtype(kLong));
auto stream = getCurrentMPSStream();
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pipeline = lib.getPipelineStateForFunc("flatten_indices_kernel");
auto encoder = stream->commandEncoder();
[encoder setComputePipelineState:pipeline];
mtl_setArgs(encoder, indices, strides, flat_indices, sparse_dim, nnz);
mtl_dispatch1DJob(encoder, pipeline, nnz);
}
});
return flat_indices;
}
static Tensor compute_output_positions(const Tensor& is_unique) {
int64_t nnz = is_unique.size(0);

View File

@@ -1,169 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_coalesce_native.h>
#include <ATen/ops/_sparse_coo_tensor_unsafe_native.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/add_native.h>
#include <ATen/ops/empty_native.h>
#include <ATen/ops/zeros_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/copy_sparse_to_sparse.h>
#include <ATen/ops/mul.h>
#endif
namespace at::native {
using namespace at::sparse;
Tensor& add_out_dense_sparse_mps(Tensor& out, const Tensor& dense, const SparseTensor& sparse, const Scalar& alpha);
Tensor& add_out_dense_sparse_mps(
Tensor& out,
const Tensor& dense,
const SparseTensor& sparse,
const Scalar& alpha) {
TORCH_CHECK(dense.is_mps(), "add: expected 'self' to be an MPS tensor, got ", dense.device());
TORCH_CHECK(sparse.is_mps(), "add: expected 'other' to be an MPS tensor, got ", sparse.device());
TORCH_CHECK(out.is_mps(), "add: expected 'out' to be an MPS tensor, got ", out.device());
TORCH_CHECK(dense.sizes().equals(sparse.sizes()),
"add: expected 'self' and 'other' to have same size, but self has size ",
dense.sizes(), " while other has size ", sparse.sizes(),
" (FYI: dense-sparse addition does not currently support broadcasting)");
const int64_t nnz = sparse._nnz();
if (nnz == 0) {
out.resize_as_(dense);
out.copy_(dense);
return out;
}
auto commonDtype = at::result_type(dense, sparse);
TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
"Can't convert result type ", commonDtype, " to output ", out.scalar_type());
Tensor r;
const bool need_separate_buffer = out.is_same(dense) || (out.scalar_type() != commonDtype);
if (need_separate_buffer) {
r = at::empty(dense.sizes(), out.options().dtype(commonDtype));
} else {
r = out;
r.resize_as_(dense);
}
Tensor dense_buffer = dense.to(commonDtype);
if (!r.is_same(dense_buffer)) {
r.copy_(dense_buffer);
}
Tensor indices = sparse._indices();
Tensor values = sparse._values().to(commonDtype);
if (values.numel() == 0) {
if (!out.is_same(r)) {
out.resize_as_(dense);
out.copy_(r);
}
return out;
}
const int64_t nDim = r.dim();
const int64_t nDimI = sparse.sparse_dim();
TORCH_CHECK(nDimI >= 0 && nDimI <= nDim,
"Invalid sparse_dim=", nDimI, " for dense tensor of dim ", nDim);
Tensor indices1D = at::sparse::flatten_indices(indices, sparse.sizes()).contiguous();
int64_t view_rows = 1;
int64_t view_cols = 1;
for (int64_t i = 0; i < nDimI; i++) {
view_rows *= r.size(i);
}
for (int64_t i = nDimI; i < nDim; i++) {
view_cols *= r.size(i);
}
if (view_cols == 1) {
Tensor r_flat = r.reshape({view_rows});
Tensor values_1d = values.reshape({nnz});
r_flat.index_add_(0, indices1D, values_1d, alpha);
} else {
Tensor r_view = r.view({view_rows, view_cols});
Tensor values_2d = values.reshape({nnz, view_cols});
r_view.index_add_(0, indices1D, values_2d, alpha);
}
if (!out.is_same(r)) {
out.resize_as_(dense);
out.copy_(r);
}
return out;
}
SparseTensor& add_out_sparse_mps(const SparseTensor& self,
const SparseTensor& other,
const Scalar& alpha,
SparseTensor& out) {
TORCH_CHECK(other.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
TORCH_CHECK(self.is_mps(), "add: expected 'self' to be MPS, but got ", self.device());
TORCH_CHECK(other.is_mps(), "add: expected 'other' to be MPS, but got ", other.device());
TORCH_CHECK(out.is_mps(), "add: expected 'out' to be MPS, but got ", out.device());
if (!self.is_sparse()) {
return add_out_dense_sparse_mps(out, self, other, alpha);
}
auto commonDtype = at::result_type(self, other);
TORCH_CHECK(canCast(commonDtype, out.scalar_type()),
"Can't convert result type ", commonDtype, " to output ", out.scalar_type());
TORCH_CHECK(self.sizes().equals(other.sizes()),
"add: expected 'self' and 'other' to have same size, but ", self.sizes(), " != ", other.sizes());
TORCH_CHECK(is_same_density(self, other),
"add: expected 'self' and 'other' to have same density, but 'self' has ",
self.sparse_dim(), " sparse dimensions while 'other' has ", other.sparse_dim(), " sparse dimensions");
if (other._nnz() == 0) {
out.resize_as_(self);
Tensor vals = self._values();
if (vals.scalar_type() != out.scalar_type()) {
vals = vals.to(out.scalar_type());
}
alias_into_sparse(out, self._indices(), vals);
out._coalesced_(self.is_coalesced());
return out;
}
Tensor t_indices_ = self._indices();
Tensor s_indices_ = other._indices();
Tensor t_values_ = self._values().to(commonDtype);
Tensor s_values_ = other._values().to(commonDtype);
if (!alpha.isIntegral(false) || alpha.to<double>() != 1.0) {
s_values_ = at::mul(s_values_, alpha);
}
Tensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
Tensor r_values_ = at::cat({t_values_, s_values_ }, 0);
SparseTensor tmp = empty({0}, out.options().dtype(commonDtype));
tmp.resize_as_(other);
alias_into_sparse(tmp, r_indices_, r_values_);
tmp = _coalesce_sparse_mps(tmp);
out.resize_as_(other);
Tensor out_vals = tmp._values();
if (out.scalar_type() != commonDtype) {
out_vals = out_vals.to(out.scalar_type());
}
alias_into_sparse(out, tmp._indices(), out_vals);
out._coalesced_(tmp.is_coalesced());
return out;
}
} // namespace at::native
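The deleted add_out_dense_sparse_mps above follows the usual dense + sparse COO recipe: promote dtypes, flatten the sparse indices over the sparse dims, view the dense buffer as (rows, cols), and scatter the alpha-scaled values with index_add_. A hedged pure-PyTorch sketch of that path, skipping the out/dtype plumbing:

```python
import torch

def add_dense_sparse_ref(dense: torch.Tensor, sparse: torch.Tensor, alpha=1.0) -> torch.Tensor:
    # Reference for dense + alpha * sparse (COO), mirroring the reverted MPS path.
    assert dense.shape == sparse.shape
    out = dense.clone().contiguous()
    nnz = sparse._nnz()
    if nnz == 0:
        return out
    n_sparse = sparse.sparse_dim()
    rows = 1
    for d in range(n_sparse):
        rows *= out.size(d)
    cols = out.numel() // rows
    # Flatten the COO indices over the sparse dims (row-major).
    flat_idx = torch.zeros(nnz, dtype=torch.int64, device=out.device)
    mul = 1
    for d in range(n_sparse - 1, -1, -1):
        flat_idx += sparse._indices()[d] * mul
        mul *= out.size(d)
    out.view(rows, cols).index_add_(0, flat_idx, sparse._values().reshape(nnz, cols), alpha=alpha)
    return out

# Usage sketch, checked against the native CPU implementation.
idx = torch.tensor([[0, 1], [2, 0]])
vals = torch.randn(2, 4)
sp = torch.sparse_coo_tensor(idx, vals, (2, 3, 4))
d = torch.randn(2, 3, 4)
torch.testing.assert_close(add_dense_sparse_ref(d, sp, 0.5), torch.add(d, sp, alpha=0.5))
```

The sparse + sparse branch (add_out_sparse_mps) instead concatenates the two operands' indices and values, scaling the second operand's values by alpha, and coalesces the result, matching the CPU/CUDA implementations.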

View File

@@ -1,19 +0,0 @@
#include <metal_stdlib>
using namespace metal;
kernel void flatten_indices_kernel(
device const long* indices [[ buffer(0) ]],
device const long* row_muls [[ buffer(1) ]],
device long* flat_indices [[ buffer(2) ]],
constant uint& sparse_dim [[ buffer(3) ]],
constant long2& idx_strides [[ buffer(4) ]],
uint gid [[ thread_position_in_grid ]]) {
long flat = 0;
for (uint d = 0; d < sparse_dim; ++d) {
long off = (long)d * idx_strides.x + (long)gid * idx_strides.y;
long v = indices[off];
flat += v * row_muls[d];
}
flat_indices[gid] = flat;
}

View File

@@ -2,6 +2,19 @@
#include <metal_stdlib>
using namespace metal;
kernel void flatten_indices_kernel(
device const int64_t* indices [[buffer(0)]],
device const int64_t* strides [[buffer(1)]],
device int64_t* flat_indices [[buffer(2)]],
constant uint& sparse_dim [[buffer(3)]],
constant uint& nnz [[buffer(4)]],
uint gid [[thread_position_in_grid]]) {
int64_t flat_idx = 0;
for (uint d = 0; d < sparse_dim; d++) {
flat_idx += indices[d * nnz + gid] * strides[d];
}
flat_indices[gid] = flat_idx;
}
kernel void compute_output_positions_kernel(
device const bool* is_unique [[buffer(0)]],

View File

@@ -12868,100 +12868,6 @@ class TestSparseMPS(TestCaseMPS):
self.assertEqual(coalesced_mps._indices().cpu(), coalesced_cpu._indices())
self.assertEqual(coalesced_mps._values().cpu(), coalesced_cpu._values())
def test_sparse_add(self):
# Basic dense + sparse add
dense_mps = torch.zeros((2, 3), device="mps", dtype=torch.float32)
sparse_mps = self._get_basic_sparse_coo(device="mps")
dense_cpu = dense_mps.cpu()
sparse_cpu = torch.sparse_coo_tensor(
sparse_mps._indices().cpu(), sparse_mps._values().cpu(), sparse_mps.size(), device="cpu"
)
res_mps = torch.add(dense_mps, sparse_mps)
res_cpu = torch.add(dense_cpu, sparse_cpu)
self.assertEqual(res_mps.cpu(), res_cpu)
# alpha scaling (integral alpha)
res_mps = torch.add(dense_mps, sparse_mps, alpha=2)
res_cpu = torch.add(dense_cpu, sparse_cpu, alpha=2)
self.assertEqual(res_mps.cpu(), res_cpu)
# alpha scaling (float alpha) with random dense
dense2_mps = torch.randn((2, 3), device="mps", dtype=torch.float32)
dense2_cpu = dense2_mps.cpu()
res_mps = torch.add(dense2_mps, sparse_mps, alpha=0.5)
res_cpu = torch.add(dense2_cpu, sparse_cpu, alpha=0.5)
self.assertEqual(res_mps.cpu(), res_cpu)
# nnz == 0 fast-path
empty_indices_mps = torch.zeros((2, 0), dtype=torch.int64, device="mps")
empty_values_mps = torch.tensor([], dtype=torch.float32, device="mps")
empty_sparse_mps = torch.sparse_coo_tensor(empty_indices_mps, empty_values_mps, (2, 3), device="mps")
empty_indices_cpu = empty_indices_mps.cpu()
empty_values_cpu = empty_values_mps.cpu()
empty_sparse_cpu = torch.sparse_coo_tensor(empty_indices_cpu, empty_values_cpu, (2, 3), device="cpu")
res_mps = torch.add(dense2_mps, empty_sparse_mps)
res_cpu = torch.add(dense2_cpu, empty_sparse_cpu)
self.assertEqual(res_mps.cpu(), res_cpu)
# 3D case to exercise view_cols > 1 path (values are 2D)
indices3_mps = torch.tensor([[0, 1], [2, 0]], dtype=torch.int64, device="mps")
values3_mps = torch.tensor([[1., 2., 3., 4.], [5., 6., 7., 8.]], dtype=torch.float32, device="mps")
size3 = (2, 3, 4)
sp3_mps = torch.sparse_coo_tensor(indices3_mps, values3_mps, size3, device="mps")
dense3_mps = torch.randn(size3, device="mps", dtype=torch.float32)
indices3_cpu = indices3_mps.cpu()
values3_cpu = values3_mps.cpu()
sp3_cpu = torch.sparse_coo_tensor(indices3_cpu, values3_cpu, size3, device="cpu")
dense3_cpu = dense3_mps.cpu()
res_mps = torch.add(dense3_mps, sp3_mps, alpha=1.0)
res_cpu = torch.add(dense3_cpu, sp3_cpu, alpha=1.0)
self.assertEqual(res_mps.cpu(), res_cpu)
# dtype promotion: dense float32 + sparse float16
sparse_f16_mps = torch.sparse_coo_tensor(
sparse_mps._indices(),
sparse_mps._values().to(torch.float16),
sparse_mps.size(),
device="mps",
)
sparse_f16_cpu = torch.sparse_coo_tensor(
sparse_f16_mps._indices().cpu(),
sparse_f16_mps._values().cpu(),
sparse_f16_mps.size(),
device="cpu",
)
res_mps = torch.add(dense2_mps, sparse_f16_mps, alpha=0.25)
res_cpu = torch.add(dense2_cpu, sparse_f16_cpu, alpha=0.25)
self.assertEqual(res_mps.cpu(), res_cpu)
# broadcasting not supported: mismatched size should error
bad_sparse_mps = torch.sparse_coo_tensor(
sparse_mps._indices(), sparse_mps._values(), (2, 4), device="mps"
)
with self.assertRaisesRegex(RuntimeError, "same size"):
torch.add(dense_mps, bad_sparse_mps)
# sparse + sparse with overlap (tests concatenation + coalesce + alpha)
s1_idx = torch.tensor([[0, 0, 1], [0, 0, 2]], dtype=torch.int64)
s1_val = torch.tensor([1., 2., 3.], dtype=torch.float32)
s2_idx = torch.tensor([[0, 1, 1], [0, 2, 2]], dtype=torch.int64)
s2_val = torch.tensor([4., 5., 6.], dtype=torch.float32)
s1_mps = torch.sparse_coo_tensor(s1_idx.to("mps"), s1_val.to("mps"), (2, 3), device="mps")
s2_mps = torch.sparse_coo_tensor(s2_idx.to("mps"), s2_val.to("mps"), (2, 3), device="mps")
s1_cpu = torch.sparse_coo_tensor(s1_idx, s1_val, (2, 3), device="cpu")
s2_cpu = torch.sparse_coo_tensor(s2_idx, s2_val, (2, 3), device="cpu")
sp_res_mps = torch.add(s1_mps, s2_mps, alpha=2.0).coalesce()
sp_res_cpu = torch.add(s1_cpu, s2_cpu, alpha=2.0).coalesce()
self.assertEqual(sp_res_mps.cpu(), sp_res_cpu)
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the

View File

@@ -14,7 +14,6 @@ from torch.testing._internal.common_utils import TestCase, run_tests, do_test_dt
parametrize, subtest, is_coalesced_indices, suppress_warnings, instantiate_parametrized_tests, \
skipIfCrossRef
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_mps import mps_ops_modifier
from numbers import Number
from typing import Any
from packaging import version
@@ -43,6 +42,7 @@ def _op_supports_any_sparse(op):
or op.supports_sparse_bsc)
reduction_ops_with_sparse_support = [
op for op in reduction_ops if 'masked.' not in op.name and
_op_supports_any_sparse(op) and not isinstance(op, ReductionPythonRefInfo)]
@@ -4126,7 +4126,7 @@ def _sparse_to_dense(tensor):
return tensor.to(torch.int8).to_dense().to(torch.bool)
_sparse_unary_ops = ops(mps_ops_modifier(sparse_unary_ufuncs, sparse=True), dtypes=OpDTypes.supported,
_sparse_unary_ops = ops(sparse_unary_ufuncs, dtypes=OpDTypes.supported,
allowed_dtypes=all_types_and_complex())
class TestSparseUnaryUfuncs(TestCase):
exact_dtype = True
@@ -4178,8 +4178,8 @@ class TestSparseUnaryUfuncs(TestCase):
@_sparse_unary_ops
def test_sparse_zero_dims(self, device, dtype, op):
# test 0x0 sparse_coo_tensor
indices = torch.empty(2, 0, dtype=torch.int64, device=device)
values = torch.empty(0, dtype=dtype, device=device)
indices = torch.empty(2, 0, dtype=torch.int64)
values = torch.empty(0, dtype=dtype)
sparse_0x0 = torch.sparse_coo_tensor(indices, values, (0, 0))
expected = torch.sparse_coo_tensor(indices, op(values), (0, 0))
actual = op(sparse_0x0)
@@ -5526,7 +5526,7 @@ class TestSparseAny(TestCase):
# e.g., TestSparseUnaryUfuncsCPU and TestSparseUnaryUfuncsCUDA
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), allow_mps=True, except_for='meta')
instantiate_device_type_tests(TestSparseUnaryUfuncs, globals(), except_for='meta')
instantiate_device_type_tests(TestSparseMaskedReductions, globals(), except_for='meta')

View File

@@ -14,7 +14,6 @@ if torch.backends.mps.is_available():
ops: Sequence[OpInfo],
device_type: Optional[str] = None,
xfail_exclusion: Optional[list[str]] = None,
sparse: bool = False,
) -> Sequence[OpInfo]:
if xfail_exclusion is None:
xfail_exclusion = []
@@ -295,7 +294,7 @@
}
# Those ops are not expected to work
UNIMPLEMENTED_XFAILLIST: dict[str, Optional[list]] = {
UNIMPLEMENTED_XFAILLIST = {
# Failures due to lack of op implementation on MPS backend
"logspace": None,
"logspacetensor_overload": None,
@@ -441,42 +440,6 @@
torch.int8,
],
}
UNIMPLEMENTED_XFAILLIST_SPARSE: dict[str, Optional[list]] = {
"logspace": None,
"logspacetensor_overload": None,
"linalg.eig": None,
"linalg.eigvals": None,
"put": None,
"deg2rad": None,
"erf": None,
"expm1": None,
"floor": None,
"frac": None,
"isneginf": None,
"isposinf": None,
"log1p": None,
"nan_to_num": None,
"neg": None,
"rad2deg": None,
"round": None,
"sgn": None,
"sign": None,
"signbit": None,
"sin": None,
"sinh": None,
"sqrt": None,
"tan": None,
"tanh": None,
"asinh": None,
"asin": None,
"isnan": None,
"isinf": None,
"atan": None,
"atanh": None,
"ceil": None,
"relu": None,
"nn.functional.relu": None,
}
if MACOS_VERSION < 15.0:
UNIMPLEMENTED_XFAILLIST.update(
@@ -485,10 +448,8 @@
"nanquantile": None,
}
)
if sparse:
UNIMPLEMENTED_XFAILLIST.update(UNIMPLEMENTED_XFAILLIST_SPARSE)
UNDEFINED_XFAILLIST: dict[str, Optional[list]] = {
UNDEFINED_XFAILLIST = {
# Top 60 operators
# topk fails with duplicate indices
"topk": [
@@ -565,7 +526,7 @@
],
}
ON_MPS_XFAILLIST: dict[str, Optional[list]] = {
ON_MPS_XFAILLIST = {
# Failures due to lack of implementation of downstream functions on MPS backend
# TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented
"linalg.matrix_rank": None,
@@ -629,45 +590,15 @@
# precision types. So we have to skip these for now.
"grid_sampler_3d": [torch.float16, torch.bfloat16],
}
SKIPLIST_SPARSE = {
# Skipped due to test_sparse_zero_dims test in test_sparse.py which allocates empty tensor
# and does basically a no-op op(positive), which leads to unexpected success
"positive": [torch.complex128],
}
def addDecorator(
op: OpInfo, d: DecorateInfo, _device_type: Optional[str] = device_type
) -> None:
if _device_type is not None:
d.device_type = _device_type
def addDecorator(op: OpInfo, d: DecorateInfo) -> None:
if device_type is not None:
d.device_type = device_type
op.decorators = op.decorators + (d,)
for op in ops:
key = op.name + op.variant_test_name
addDecorator(
op,
DecorateInfo(
unittest.expectedFailure,
dtypes=[
torch.double,
torch.cdouble,
],
),
_device_type="mps",
)
if sparse and op.name in SKIPLIST_SPARSE:
addDecorator(
op,
DecorateInfo(
unittest.skip(
"Skipped due to MPS not supporting complex128 tensors"
),
dtypes=[
torch.complex128,
],
),
)
if key in EMPTY_OPS_SKIPLIST:
addDecorator(
op,
@@ -689,7 +620,6 @@
addDecorator(
op,
DecorateInfo(unittest.expectedFailure, dtypes=xfaillist[key]),
_device_type="mps",
)
if (
@@ -875,12 +805,3 @@
addDecorator(op, DecorateInfo(unittest.expectedFailure))
return ops
else:
def mps_ops_modifier(
ops: Sequence[OpInfo],
device_type: Optional[str] = None,
xfail_exclusion: Optional[list[str]] = None,
sparse: bool = False,
) -> Sequence[OpInfo]:
return ops